diff --git a/adapter/outbound/tuic.go b/adapter/outbound/tuic.go index ab371757..e2aafca5 100644 --- a/adapter/outbound/tuic.go +++ b/adapter/outbound/tuic.go @@ -13,13 +13,14 @@ import ( "strconv" "time" - "github.com/metacubex/quic-go" - "github.com/Dreamacro/clash/component/dialer" "github.com/Dreamacro/clash/component/proxydialer" tlsC "github.com/Dreamacro/clash/component/tls" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/transport/tuic" + + "github.com/gofrs/uuid/v5" + "github.com/metacubex/quic-go" ) type Tuic struct { @@ -33,7 +34,9 @@ type TuicOption struct { Name string `proxy:"name"` Server string `proxy:"server"` Port int `proxy:"port"` - Token string `proxy:"token"` + Token string `proxy:"token,omitempty"` + UUID string `proxy:"uuid,omitempty"` + Password string `proxy:"password,omitempty"` Ip string `proxy:"ip,omitempty"` HeartbeatInterval int `proxy:"heartbeat-interval,omitempty"` ALPN []string `proxy:"alpn,omitempty"` @@ -184,14 +187,19 @@ func NewTuic(option TuicOption) (*Tuic, error) { option.MaxOpenStreams = 100 } + packetOverHead := tuic.PacketOverHeadV4 + if len(option.Token) == 0 { + packetOverHead = tuic.PacketOverHeadV5 + } + if option.MaxDatagramFrameSize == 0 { - option.MaxDatagramFrameSize = option.MaxUdpRelayPacketSize + tuic.PacketOverHead + option.MaxDatagramFrameSize = option.MaxUdpRelayPacketSize + packetOverHead } if option.MaxDatagramFrameSize > 1400 { option.MaxDatagramFrameSize = 1400 } - option.MaxUdpRelayPacketSize = option.MaxDatagramFrameSize - tuic.PacketOverHead + option.MaxUdpRelayPacketSize = option.MaxDatagramFrameSize - packetOverHead // ensure server's incoming stream can handle correctly, increase to 1.1x quicMaxOpenStreams := int64(option.MaxOpenStreams) @@ -222,8 +230,8 @@ func NewTuic(option TuicOption) (*Tuic, error) { } if option.DisableSni { tlsConfig.ServerName = "" + tlsConfig.InsecureSkipVerify = true // tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config } - tkn := tuic.GenTKN(option.Token) t := &Tuic{ Base: &Base{ @@ -249,20 +257,38 @@ func NewTuic(option TuicOption) (*Tuic, error) { if clientMaxOpenStreams < 1 { clientMaxOpenStreams = 1 } - clientOption := &tuic.ClientOption{ - TlsConfig: tlsConfig, - QuicConfig: quicConfig, - Token: tkn, - UdpRelayMode: option.UdpRelayMode, - CongestionController: option.CongestionController, - ReduceRtt: option.ReduceRtt, - RequestTimeout: time.Duration(option.RequestTimeout) * time.Millisecond, - MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize, - FastOpen: option.FastOpen, - MaxOpenStreams: clientMaxOpenStreams, - } - t.client = tuic.NewPoolClient(clientOption) + if len(option.Token) > 0 { + tkn := tuic.GenTKN(option.Token) + clientOption := &tuic.ClientOptionV4{ + TlsConfig: tlsConfig, + QuicConfig: quicConfig, + Token: tkn, + UdpRelayMode: option.UdpRelayMode, + CongestionController: option.CongestionController, + ReduceRtt: option.ReduceRtt, + RequestTimeout: time.Duration(option.RequestTimeout) * time.Millisecond, + MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize, + FastOpen: option.FastOpen, + MaxOpenStreams: clientMaxOpenStreams, + } + + t.client = tuic.NewPoolClientV4(clientOption) + } else { + clientOption := &tuic.ClientOptionV5{ + TlsConfig: tlsConfig, + QuicConfig: quicConfig, + Uuid: uuid.FromStringOrNil(option.UUID), + Password: option.Password, + UdpRelayMode: option.UdpRelayMode, + CongestionController: option.CongestionController, + ReduceRtt: option.ReduceRtt, + MaxUdpRelayPacketSize: option.MaxUdpRelayPacketSize, + MaxOpenStreams: clientMaxOpenStreams, + } + + t.client = tuic.NewPoolClientV5(clientOption) + } return t, nil } diff --git a/config/config.go b/config/config.go index 0a263a08..0e161ddd 100644 --- a/config/config.go +++ b/config/config.go @@ -219,16 +219,17 @@ type RawTun struct { } type RawTuicServer struct { - Enable bool `yaml:"enable" json:"enable"` - Listen string `yaml:"listen" json:"listen"` - Token []string `yaml:"token" json:"token"` - Certificate string `yaml:"certificate" json:"certificate"` - PrivateKey string `yaml:"private-key" json:"private-key"` - CongestionController string `yaml:"congestion-controller" json:"congestion-controller,omitempty"` - MaxIdleTime int `yaml:"max-idle-time" json:"max-idle-time,omitempty"` - AuthenticationTimeout int `yaml:"authentication-timeout" json:"authentication-timeout,omitempty"` - ALPN []string `yaml:"alpn" json:"alpn,omitempty"` - MaxUdpRelayPacketSize int `yaml:"max-udp-relay-packet-size" json:"max-udp-relay-packet-size,omitempty"` + Enable bool `yaml:"enable" json:"enable"` + Listen string `yaml:"listen" json:"listen"` + Token []string `yaml:"token" json:"token"` + Users map[string]string `yaml:"users" json:"users,omitempty"` + Certificate string `yaml:"certificate" json:"certificate"` + PrivateKey string `yaml:"private-key" json:"private-key"` + CongestionController string `yaml:"congestion-controller" json:"congestion-controller,omitempty"` + MaxIdleTime int `yaml:"max-idle-time" json:"max-idle-time,omitempty"` + AuthenticationTimeout int `yaml:"authentication-timeout" json:"authentication-timeout,omitempty"` + ALPN []string `yaml:"alpn" json:"alpn,omitempty"` + MaxUdpRelayPacketSize int `yaml:"max-udp-relay-packet-size" json:"max-udp-relay-packet-size,omitempty"` } type RawConfig struct { @@ -355,6 +356,7 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) { TuicServer: RawTuicServer{ Enable: false, Token: nil, + Users: nil, Certificate: "", PrivateKey: "", Listen: "", @@ -1294,6 +1296,7 @@ func parseTuicServer(rawTuic RawTuicServer, general *General) error { Enable: rawTuic.Enable, Listen: rawTuic.Listen, Token: rawTuic.Token, + Users: rawTuic.Users, Certificate: rawTuic.Certificate, PrivateKey: rawTuic.PrivateKey, CongestionController: rawTuic.CongestionController, diff --git a/docs/config.yaml b/docs/config.yaml index 311cf50f..a8207917 100644 --- a/docs/config.yaml +++ b/docs/config.yaml @@ -661,7 +661,11 @@ proxies: # socks5 server: www.example.com port: 10443 type: tuic + # tuicV4必须填写token (不可同时填写uuid和password) token: TOKEN + # tuicV5必须填写uuid和password(不可同时填写token) + uuid: 00000000-0000-0000-0000-000000000001 + password: PASSWORD_1 # ip: 127.0.0.1 # for overwriting the DNS lookup result of the server address set in option 'server' # heartbeat-interval: 10000 # alpn: [h3] @@ -899,8 +903,11 @@ listeners: listen: 0.0.0.0 # rule: sub-rule-name1 # 默认使用 rules,如果未找到 sub-rule 则直接使用 rules # proxy: proxy # 如果不为空则直接将该入站流量交由指定proxy处理(当proxy不为空时,这里的proxy名称必须合法,否则会出错) - # token: - # - TOKEN + # token: # tuicV4填写(不可同时填写users) + # - TOKEN + # users: # tuicV5填写(不可同时填写token) + # 00000000-0000-0000-0000-000000000000: PASSWORD_0 + # 00000000-0000-0000-0000-000000000001: PASSWORD_1 # certificate: ./server.crt # private-key: ./server.key # congestion-controller: bbr @@ -970,8 +977,11 @@ listeners: # tuic-server: # enable: true # listen: 127.0.0.1:10443 -# token: +# token: # tuicV4填写(不可同时填写users) # - TOKEN +# users: # tuicV5填写(不可同时填写token) +# 00000000-0000-0000-0000-000000000000: PASSWORD_0 +# 00000000-0000-0000-0000-000000000001: PASSWORD_1 # certificate: ./server.crt # private-key: ./server.key # congestion-controller: bbr diff --git a/hub/route/configs.go b/hub/route/configs.go index afafe80e..a8c24f90 100644 --- a/hub/route/configs.go +++ b/hub/route/configs.go @@ -84,16 +84,17 @@ type tunSchema struct { } type tuicServerSchema struct { - Enable bool `yaml:"enable" json:"enable"` - Listen *string `yaml:"listen" json:"listen"` - Token *[]string `yaml:"token" json:"token"` - Certificate *string `yaml:"certificate" json:"certificate"` - PrivateKey *string `yaml:"private-key" json:"private-key"` - CongestionController *string `yaml:"congestion-controller" json:"congestion-controller,omitempty"` - MaxIdleTime *int `yaml:"max-idle-time" json:"max-idle-time,omitempty"` - AuthenticationTimeout *int `yaml:"authentication-timeout" json:"authentication-timeout,omitempty"` - ALPN *[]string `yaml:"alpn" json:"alpn,omitempty"` - MaxUdpRelayPacketSize *int `yaml:"max-udp-relay-packet-size" json:"max-udp-relay-packet-size,omitempty"` + Enable bool `yaml:"enable" json:"enable"` + Listen *string `yaml:"listen" json:"listen"` + Token *[]string `yaml:"token" json:"token"` + Users *map[string]string `yaml:"users" json:"users,omitempty"` + Certificate *string `yaml:"certificate" json:"certificate"` + PrivateKey *string `yaml:"private-key" json:"private-key"` + CongestionController *string `yaml:"congestion-controller" json:"congestion-controller,omitempty"` + MaxIdleTime *int `yaml:"max-idle-time" json:"max-idle-time,omitempty"` + AuthenticationTimeout *int `yaml:"authentication-timeout" json:"authentication-timeout,omitempty"` + ALPN *[]string `yaml:"alpn" json:"alpn,omitempty"` + MaxUdpRelayPacketSize *int `yaml:"max-udp-relay-packet-size" json:"max-udp-relay-packet-size,omitempty"` } func getConfigs(w http.ResponseWriter, r *http.Request) { @@ -186,6 +187,9 @@ func pointerOrDefaultTuicServer(p *tuicServerSchema, def LC.TuicServer) LC.TuicS if p.Token != nil { def.Token = *p.Token } + if p.Users != nil { + def.Users = *p.Users + } if p.Certificate != nil { def.Certificate = *p.Certificate } diff --git a/listener/config/tuic.go b/listener/config/tuic.go index 991a04c9..30c99054 100644 --- a/listener/config/tuic.go +++ b/listener/config/tuic.go @@ -5,17 +5,18 @@ import ( ) type TuicServer struct { - Enable bool `yaml:"enable" json:"enable"` - Listen string `yaml:"listen" json:"listen"` - Token []string `yaml:"token" json:"token"` - Certificate string `yaml:"certificate" json:"certificate"` - PrivateKey string `yaml:"private-key" json:"private-key"` - CongestionController string `yaml:"congestion-controller" json:"congestion-controller,omitempty"` - MaxIdleTime int `yaml:"max-idle-time" json:"max-idle-time,omitempty"` - AuthenticationTimeout int `yaml:"authentication-timeout" json:"authentication-timeout,omitempty"` - ALPN []string `yaml:"alpn" json:"alpn,omitempty"` - MaxUdpRelayPacketSize int `yaml:"max-udp-relay-packet-size" json:"max-udp-relay-packet-size,omitempty"` - MaxDatagramFrameSize int `yaml:"max-datagram-frame-size" json:"max-datagram-frame-size,omitempty"` + Enable bool `yaml:"enable" json:"enable"` + Listen string `yaml:"listen" json:"listen"` + Token []string `yaml:"token" json:"token,omitempty"` + Users map[string]string `yaml:"users" json:"users,omitempty"` + Certificate string `yaml:"certificate" json:"certificate"` + PrivateKey string `yaml:"private-key" json:"private-key"` + CongestionController string `yaml:"congestion-controller" json:"congestion-controller,omitempty"` + MaxIdleTime int `yaml:"max-idle-time" json:"max-idle-time,omitempty"` + AuthenticationTimeout int `yaml:"authentication-timeout" json:"authentication-timeout,omitempty"` + ALPN []string `yaml:"alpn" json:"alpn,omitempty"` + MaxUdpRelayPacketSize int `yaml:"max-udp-relay-packet-size" json:"max-udp-relay-packet-size,omitempty"` + MaxDatagramFrameSize int `yaml:"max-datagram-frame-size" json:"max-datagram-frame-size,omitempty"` } func (t TuicServer) String() string { diff --git a/listener/inbound/tuic.go b/listener/inbound/tuic.go index f6641500..2e234e2d 100644 --- a/listener/inbound/tuic.go +++ b/listener/inbound/tuic.go @@ -9,14 +9,15 @@ import ( type TuicOption struct { BaseOption - Token []string `inbound:"token"` - Certificate string `inbound:"certificate"` - PrivateKey string `inbound:"private-key"` - CongestionController string `inbound:"congestion-controller,omitempty"` - MaxIdleTime int `inbound:"max-idle-time,omitempty"` - AuthenticationTimeout int `inbound:"authentication-timeout,omitempty"` - ALPN []string `inbound:"alpn,omitempty"` - MaxUdpRelayPacketSize int `inbound:"max-udp-relay-packet-size,omitempty"` + Token []string `inbound:"token,omitempty"` + Users map[string]string `inbound:"users,omitempty"` + Certificate string `inbound:"certificate"` + PrivateKey string `inbound:"private-key"` + CongestionController string `inbound:"congestion-controller,omitempty"` + MaxIdleTime int `inbound:"max-idle-time,omitempty"` + AuthenticationTimeout int `inbound:"authentication-timeout,omitempty"` + ALPN []string `inbound:"alpn,omitempty"` + MaxUdpRelayPacketSize int `inbound:"max-udp-relay-packet-size,omitempty"` } func (o TuicOption) Equal(config C.InboundConfig) bool { @@ -42,6 +43,7 @@ func NewTuic(options *TuicOption) (*Tuic, error) { Enable: true, Listen: base.RawAddress(), Token: options.Token, + Users: options.Users, Certificate: options.Certificate, PrivateKey: options.PrivateKey, CongestionController: options.CongestionController, diff --git a/listener/tuic/server.go b/listener/tuic/server.go index 498708bf..e1d8175c 100644 --- a/listener/tuic/server.go +++ b/listener/tuic/server.go @@ -6,8 +6,6 @@ import ( "strings" "time" - "github.com/metacubex/quic-go" - "github.com/Dreamacro/clash/adapter/inbound" CN "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/sockopt" @@ -16,6 +14,10 @@ import ( "github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/transport/socks5" "github.com/Dreamacro/clash/transport/tuic" + + "github.com/gofrs/uuid/v5" + "github.com/metacubex/quic-go" + "golang.org/x/exp/slices" ) const ServerMaxIncomingStreams = (1 << 32) - 1 @@ -24,7 +26,7 @@ type Listener struct { closed bool config LC.TuicServer udpListeners []net.PacketConn - servers []*tuic.Server + servers []tuic.Server } func New(config LC.TuicServer, tcpIn chan<- C.ConnContext, udpIn chan<- C.PacketAdapter, additions ...inbound.Addition) (*Listener, error) { @@ -59,39 +61,77 @@ func New(config LC.TuicServer, tcpIn chan<- C.ConnContext, udpIn chan<- C.Packet quicConfig.InitialConnectionReceiveWindow = tuic.DefaultConnectionReceiveWindow / 10 quicConfig.MaxConnectionReceiveWindow = tuic.DefaultConnectionReceiveWindow + packetOverHead := tuic.PacketOverHeadV4 + if len(config.Token) == 0 { + packetOverHead = tuic.PacketOverHeadV5 + } + if config.MaxUdpRelayPacketSize == 0 { config.MaxUdpRelayPacketSize = 1500 } - maxDatagramFrameSize := config.MaxUdpRelayPacketSize + tuic.PacketOverHead + maxDatagramFrameSize := config.MaxUdpRelayPacketSize + packetOverHead if maxDatagramFrameSize > 1400 { maxDatagramFrameSize = 1400 } - config.MaxUdpRelayPacketSize = maxDatagramFrameSize - tuic.PacketOverHead + config.MaxUdpRelayPacketSize = maxDatagramFrameSize - packetOverHead quicConfig.MaxDatagramFrameSize = int64(maxDatagramFrameSize) - tokens := make([][32]byte, len(config.Token)) - for i, token := range config.Token { - tokens[i] = tuic.GenTKN(token) + handleTcpFn := func(conn net.Conn, addr socks5.Addr, _additions ...inbound.Addition) error { + newAdditions := additions + if len(_additions) > 0 { + newAdditions = slices.Clone(additions) + newAdditions = append(newAdditions, _additions...) + } + tcpIn <- inbound.NewSocket(addr, conn, C.TUIC, newAdditions...) + return nil + } + handleUdpFn := func(addr socks5.Addr, packet C.UDPPacket, _additions ...inbound.Addition) error { + newAdditions := additions + if len(_additions) > 0 { + newAdditions = slices.Clone(additions) + newAdditions = append(newAdditions, _additions...) + } + select { + case udpIn <- inbound.NewPacket(addr, packet, C.TUIC, newAdditions...): + default: + } + return nil } - option := &tuic.ServerOption{ - HandleTcpFn: func(conn net.Conn, addr socks5.Addr) error { - tcpIn <- inbound.NewSocket(addr, conn, C.TUIC, additions...) - return nil - }, - HandleUdpFn: func(addr socks5.Addr, packet C.UDPPacket) error { - select { - case udpIn <- inbound.NewPacket(addr, packet, C.TUIC, additions...): - default: - } - return nil - }, - TlsConfig: tlsConfig, - QuicConfig: quicConfig, - Tokens: tokens, - CongestionController: config.CongestionController, - AuthenticationTimeout: time.Duration(config.AuthenticationTimeout) * time.Millisecond, - MaxUdpRelayPacketSize: config.MaxUdpRelayPacketSize, + var optionV4 *tuic.ServerOptionV4 + var optionV5 *tuic.ServerOptionV5 + if len(config.Token) > 0 { + tokens := make([][32]byte, len(config.Token)) + for i, token := range config.Token { + tokens[i] = tuic.GenTKN(token) + } + + optionV4 = &tuic.ServerOptionV4{ + HandleTcpFn: handleTcpFn, + HandleUdpFn: handleUdpFn, + TlsConfig: tlsConfig, + QuicConfig: quicConfig, + Tokens: tokens, + CongestionController: config.CongestionController, + AuthenticationTimeout: time.Duration(config.AuthenticationTimeout) * time.Millisecond, + MaxUdpRelayPacketSize: config.MaxUdpRelayPacketSize, + } + } else { + users := make(map[[16]byte]string) + for _uuid, password := range config.Users { + users[uuid.FromStringOrNil(_uuid)] = password + } + + optionV5 = &tuic.ServerOptionV5{ + HandleTcpFn: handleTcpFn, + HandleUdpFn: handleUdpFn, + TlsConfig: tlsConfig, + QuicConfig: quicConfig, + Users: users, + CongestionController: config.CongestionController, + AuthenticationTimeout: time.Duration(config.AuthenticationTimeout) * time.Millisecond, + MaxUdpRelayPacketSize: config.MaxUdpRelayPacketSize, + } } sl := &Listener{false, config, nil, nil} @@ -111,7 +151,12 @@ func New(config LC.TuicServer, tcpIn chan<- C.ConnContext, udpIn chan<- C.Packet sl.udpListeners = append(sl.udpListeners, ul) - server, err := tuic.NewServer(option, ul) + var server tuic.Server + if optionV4 != nil { + server, err = tuic.NewServerV4(optionV4, ul) + } else { + server, err = tuic.NewServerV5(optionV5, ul) + } if err != nil { return nil, err } diff --git a/transport/tuic/common/congestion.go b/transport/tuic/common/congestion.go new file mode 100644 index 00000000..e2f7d867 --- /dev/null +++ b/transport/tuic/common/congestion.go @@ -0,0 +1,44 @@ +package common + +import ( + "github.com/Dreamacro/clash/transport/tuic/congestion" + + "github.com/metacubex/quic-go" +) + +const ( + DefaultStreamReceiveWindow = 15728640 // 15 MB/s + DefaultConnectionReceiveWindow = 67108864 // 64 MB/s +) + +func SetCongestionController(quicConn quic.Connection, cc string) { + switch cc { + case "cubic": + quicConn.SetCongestionControl( + congestion.NewCubicSender( + congestion.DefaultClock{}, + congestion.GetInitialPacketSize(quicConn.RemoteAddr()), + false, + nil, + ), + ) + case "new_reno": + quicConn.SetCongestionControl( + congestion.NewCubicSender( + congestion.DefaultClock{}, + congestion.GetInitialPacketSize(quicConn.RemoteAddr()), + true, + nil, + ), + ) + case "bbr": + quicConn.SetCongestionControl( + congestion.NewBBRSender( + congestion.DefaultClock{}, + congestion.GetInitialPacketSize(quicConn.RemoteAddr()), + congestion.InitialCongestionWindow*congestion.InitialMaxDatagramSize, + congestion.DefaultBBRMaxCongestionWindow*congestion.InitialMaxDatagramSize, + ), + ) + } +} diff --git a/transport/tuic/common/stream.go b/transport/tuic/common/stream.go new file mode 100644 index 00000000..e65f9a49 --- /dev/null +++ b/transport/tuic/common/stream.go @@ -0,0 +1,67 @@ +package common + +import ( + "net" + "sync" + "time" + + "github.com/metacubex/quic-go" +) + +type quicStreamConn struct { + quic.Stream + lock sync.Mutex + lAddr net.Addr + rAddr net.Addr + + closeDeferFn func() + + closeOnce sync.Once + closeErr error +} + +func (q *quicStreamConn) Write(p []byte) (n int, err error) { + q.lock.Lock() + defer q.lock.Unlock() + return q.Stream.Write(p) +} + +func (q *quicStreamConn) Close() error { + q.closeOnce.Do(func() { + q.closeErr = q.close() + }) + return q.closeErr +} + +func (q *quicStreamConn) close() error { + if q.closeDeferFn != nil { + defer q.closeDeferFn() + } + + // https://github.com/cloudflare/cloudflared/commit/ed2bac026db46b239699ac5ce4fcf122d7cab2cd + // Make sure a possible writer does not block the lock forever. We need it, so we can close the writer + // side of the stream safely. + _ = q.Stream.SetWriteDeadline(time.Now()) + + // This lock is eventually acquired despite Write also acquiring it, because we set a deadline to writes. + q.lock.Lock() + defer q.lock.Unlock() + + // We have to clean up the receiving stream ourselves since the Close in the bottom does not handle that. + q.Stream.CancelRead(0) + return q.Stream.Close() +} + +func (q *quicStreamConn) LocalAddr() net.Addr { + return q.lAddr +} + +func (q *quicStreamConn) RemoteAddr() net.Addr { + return q.rAddr +} + +var _ net.Conn = (*quicStreamConn)(nil) + +func NewQuicStreamConn(stream quic.Stream, lAddr, rAddr net.Addr, closeDeferFn func()) net.Conn { + return &quicStreamConn{Stream: stream, lAddr: lAddr, rAddr: rAddr, closeDeferFn: closeDeferFn} +} diff --git a/transport/tuic/common/type.go b/transport/tuic/common/type.go new file mode 100644 index 00000000..16c6f49e --- /dev/null +++ b/transport/tuic/common/type.go @@ -0,0 +1,34 @@ +package common + +import ( + "context" + "errors" + "net" + "time" + + C "github.com/Dreamacro/clash/constant" + + "github.com/metacubex/quic-go" +) + +var ( + ClientClosed = errors.New("tuic: client closed") + TooManyOpenStreams = errors.New("tuic: too many open streams") +) + +type DialFunc func(ctx context.Context, dialer C.Dialer) (transport *quic.Transport, addr net.Addr, err error) + +type Client interface { + DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.Conn, error) + ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.PacketConn, error) + OpenStreams() int64 + DialerRef() C.Dialer + LastVisited() time.Time + SetLastVisited(last time.Time) + Close() +} + +type Server interface { + Serve() error + Close() error +} diff --git a/transport/tuic/pool_client.go b/transport/tuic/pool_client.go index 223436cd..4a779706 100644 --- a/transport/tuic/pool_client.go +++ b/transport/tuic/pool_client.go @@ -23,15 +23,14 @@ type dialResult struct { } type PoolClient struct { - *ClientOption - - newClientOption *ClientOption - dialResultMap map[C.Dialer]dialResult - dialResultMutex *sync.Mutex - tcpClients *list.List[*Client] - tcpClientsMutex *sync.Mutex - udpClients *list.List[*Client] - udpClientsMutex *sync.Mutex + newClientOptionV4 *ClientOptionV4 + newClientOptionV5 *ClientOptionV5 + dialResultMap map[C.Dialer]dialResult + dialResultMutex *sync.Mutex + tcpClients *list.List[Client] + tcpClientsMutex *sync.Mutex + udpClients *list.List[Client] + udpClientsMutex *sync.Mutex } func (t *PoolClient) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.Conn, error) { @@ -99,7 +98,7 @@ func (t *PoolClient) forceClose() { } } -func (t *PoolClient) newClient(udp bool, dialer C.Dialer) *Client { +func (t *PoolClient) newClient(udp bool, dialer C.Dialer) (client Client) { clients := t.tcpClients clientsMutex := t.tcpClientsMutex if udp { @@ -110,22 +109,26 @@ func (t *PoolClient) newClient(udp bool, dialer C.Dialer) *Client { clientsMutex.Lock() defer clientsMutex.Unlock() - client := NewClient(t.newClientOption, udp) - client.dialerRef = dialer - client.lastVisited = time.Now() + if t.newClientOptionV4 != nil { + client = NewClientV4(t.newClientOptionV4, udp, dialer) + } else { + client = NewClientV5(t.newClientOptionV5, udp, dialer) + } + + client.SetLastVisited(time.Now()) clients.PushFront(client) return client } -func (t *PoolClient) getClient(udp bool, dialer C.Dialer) *Client { +func (t *PoolClient) getClient(udp bool, dialer C.Dialer) Client { clients := t.tcpClients clientsMutex := t.tcpClientsMutex if udp { clients = t.udpClients clientsMutex = t.udpClientsMutex } - var bestClient *Client + var bestClient Client func() { clientsMutex.Lock() @@ -138,11 +141,11 @@ func (t *PoolClient) getClient(udp bool, dialer C.Dialer) *Client { it = next continue } - if client.dialerRef == dialer { + if client.DialerRef() == dialer { if bestClient == nil { bestClient = client } else { - if client.openStreams.Load() < bestClient.openStreams.Load() { + if client.OpenStreams() < bestClient.OpenStreams() { bestClient = client } } @@ -152,7 +155,7 @@ func (t *PoolClient) getClient(udp bool, dialer C.Dialer) *Client { }() for it := clients.Front(); it != nil; { client := it.Value - if client != bestClient && client.openStreams.Load() == 0 && time.Now().Sub(client.lastVisited) > 30*time.Minute { + if client != bestClient && client.OpenStreams() == 0 && time.Now().Sub(client.LastVisited()) > 30*time.Minute { client.Close() next := it.Next() clients.Remove(it) @@ -165,25 +168,40 @@ func (t *PoolClient) getClient(udp bool, dialer C.Dialer) *Client { if bestClient == nil { return t.newClient(udp, dialer) } else { - bestClient.lastVisited = time.Now() + bestClient.SetLastVisited(time.Now()) return bestClient } } -func NewPoolClient(clientOption *ClientOption) *PoolClient { +func NewPoolClientV4(clientOption *ClientOptionV4) *PoolClient { p := &PoolClient{ - ClientOption: clientOption, dialResultMap: make(map[C.Dialer]dialResult), dialResultMutex: &sync.Mutex{}, - tcpClients: list.New[*Client](), + tcpClients: list.New[Client](), tcpClientsMutex: &sync.Mutex{}, - udpClients: list.New[*Client](), + udpClients: list.New[Client](), udpClientsMutex: &sync.Mutex{}, } newClientOption := *clientOption - p.newClientOption = &newClientOption + p.newClientOptionV4 = &newClientOption runtime.SetFinalizer(p, closeClientPool) - log.Debugln("New Tuic PoolClient at %p", p) + log.Debugln("New TuicV4 PoolClient at %p", p) + return p +} + +func NewPoolClientV5(clientOption *ClientOptionV5) *PoolClient { + p := &PoolClient{ + dialResultMap: make(map[C.Dialer]dialResult), + dialResultMutex: &sync.Mutex{}, + tcpClients: list.New[Client](), + tcpClientsMutex: &sync.Mutex{}, + udpClients: list.New[Client](), + udpClientsMutex: &sync.Mutex{}, + } + newClientOption := *clientOption + p.newClientOptionV5 = &newClientOption + runtime.SetFinalizer(p, closeClientPool) + log.Debugln("New TuicV5 PoolClient at %p", p) return p } diff --git a/transport/tuic/tuic.go b/transport/tuic/tuic.go new file mode 100644 index 00000000..279cec95 --- /dev/null +++ b/transport/tuic/tuic.go @@ -0,0 +1,47 @@ +package tuic + +import ( + "net" + + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/transport/tuic/common" + v4 "github.com/Dreamacro/clash/transport/tuic/v4" + v5 "github.com/Dreamacro/clash/transport/tuic/v5" +) + +type ClientOptionV4 = v4.ClientOption +type ClientOptionV5 = v5.ClientOption + +type Client = common.Client + +func NewClientV4(clientOption *ClientOptionV4, udp bool, dialerRef C.Dialer) Client { + return v4.NewClient(clientOption, udp, dialerRef) +} + +func NewClientV5(clientOption *ClientOptionV5, udp bool, dialerRef C.Dialer) Client { + return v5.NewClient(clientOption, udp, dialerRef) +} + +type DialFunc = common.DialFunc + +var TooManyOpenStreams = common.TooManyOpenStreams + +type ServerOptionV4 = v4.ServerOption +type ServerOptionV5 = v5.ServerOption + +type Server = common.Server + +func NewServerV4(option *ServerOptionV4, pc net.PacketConn) (Server, error) { + return v4.NewServer(option, pc) +} + +func NewServerV5(option *ServerOptionV5, pc net.PacketConn) (Server, error) { + return v5.NewServer(option, pc) +} + +const DefaultStreamReceiveWindow = common.DefaultStreamReceiveWindow +const DefaultConnectionReceiveWindow = common.DefaultConnectionReceiveWindow + +var GenTKN = v4.GenTKN +var PacketOverHeadV4 = v4.PacketOverHead +var PacketOverHeadV5 = v5.PacketOverHead diff --git a/transport/tuic/client.go b/transport/tuic/v4/client.go similarity index 85% rename from transport/tuic/client.go rename to transport/tuic/v4/client.go index 6fd2a241..ae0cf473 100644 --- a/transport/tuic/client.go +++ b/transport/tuic/v4/client.go @@ -1,4 +1,4 @@ -package tuic +package v4 import ( "bufio" @@ -13,23 +13,18 @@ import ( "time" "unsafe" + atomic2 "github.com/Dreamacro/clash/common/atomic" "github.com/Dreamacro/clash/common/buf" N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" + "github.com/Dreamacro/clash/transport/tuic/common" "github.com/metacubex/quic-go" "github.com/zhangyunhao116/fastrand" ) -var ( - ClientClosed = errors.New("tuic: client closed") - TooManyOpenStreams = errors.New("tuic: too many open streams") -) - -type DialFunc func(ctx context.Context, dialer C.Dialer) (transport *quic.Transport, addr net.Addr, err error) - type ClientOption struct { TlsConfig *tls.Config QuicConfig *quic.Config @@ -57,10 +52,26 @@ type clientImpl struct { // only ready for PoolClient dialerRef C.Dialer - lastVisited time.Time + lastVisited atomic2.TypedValue[time.Time] } -func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn DialFunc) (quic.Connection, error) { +func (t *clientImpl) OpenStreams() int64 { + return t.openStreams.Load() +} + +func (t *clientImpl) DialerRef() C.Dialer { + return t.dialerRef +} + +func (t *clientImpl) LastVisited() time.Time { + return t.lastVisited.Load() +} + +func (t *clientImpl) SetLastVisited(last time.Time) { + t.lastVisited.Store(last) +} + +func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn common.DialFunc) (quic.Connection, error) { t.connMutex.Lock() defer t.connMutex.Unlock() if t.quicConn != nil { @@ -80,7 +91,7 @@ func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn Di return nil, err } - SetCongestionController(quicConn, t.CongestionController) + common.SetCongestionController(quicConn, t.CongestionController) go func() { _ = t.sendAuthentication(quicConn) @@ -237,11 +248,11 @@ func (t *clientImpl) forceClose(quicConn quic.Connection, err error) { func (t *clientImpl) Close() { t.closed.Store(true) if t.openStreams.Load() == 0 { - t.forceClose(nil, ClientClosed) + t.forceClose(nil, common.ClientClosed) } } -func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.Conn, error) { +func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) { quicConn, err := t.getQuicConn(ctx, dialer, dialFn) if err != nil { return nil, err @@ -249,9 +260,9 @@ func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Meta openStreams := t.openStreams.Add(1) if openStreams >= t.MaxOpenStreams { t.openStreams.Add(-1) - return nil, TooManyOpenStreams + return nil, common.TooManyOpenStreams } - stream, err := func() (stream *quicStreamConn, err error) { + stream, err := func() (stream net.Conn, err error) { defer func() { t.deferQuicConn(quicConn, err) }() @@ -265,19 +276,19 @@ func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Meta if err != nil { return nil, err } - stream = &quicStreamConn{ - Stream: quicStream, - lAddr: quicConn.LocalAddr(), - rAddr: quicConn.RemoteAddr(), - closeDeferFn: func() { + stream = common.NewQuicStreamConn( + quicStream, + quicConn.LocalAddr(), + quicConn.RemoteAddr(), + func() { time.AfterFunc(C.DefaultTCPTimeout, func() { openStreams := t.openStreams.Add(-1) if openStreams == 0 && t.closed.Load() { - t.forceClose(quicConn, ClientClosed) + t.forceClose(quicConn, common.ClientClosed) } }) }, - } + ) _, err = buf.WriteTo(stream) if err != nil { _ = stream.Close() @@ -361,7 +372,7 @@ func (conn *earlyConn) WriterReplaceable() bool { return true } -func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.PacketConn, error) { +func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) { quicConn, err := t.getQuicConn(ctx, dialer, dialFn) if err != nil { return nil, err @@ -369,7 +380,7 @@ func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Met openStreams := t.openStreams.Add(1) if openStreams >= t.MaxOpenStreams { t.openStreams.Add(-1) - return nil, TooManyOpenStreams + return nil, common.TooManyOpenStreams } pipe1, pipe2 := net.Pipe() @@ -393,7 +404,7 @@ func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Met time.AfterFunc(C.DefaultUDPTimeout, func() { openStreams := t.openStreams.Add(-1) if openStreams == 0 && t.closed.Load() { - t.forceClose(quicConn, ClientClosed) + t.forceClose(quicConn, common.ClientClosed) } }) }, @@ -405,7 +416,7 @@ type Client struct { *clientImpl // use an independent pointer to let Finalizer can work no matter somewhere handle an influence in clientImpl inner } -func (t *Client) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.Conn, error) { +func (t *Client) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) { conn, err := t.clientImpl.DialContextWithDialer(ctx, metadata, dialer, dialFn) if err != nil { return nil, err @@ -413,7 +424,7 @@ func (t *Client) DialContextWithDialer(ctx context.Context, metadata *C.Metadata return N.NewRefConn(conn, t), err } -func (t *Client) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn DialFunc) (net.PacketConn, error) { +func (t *Client) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) { pc, err := t.clientImpl.ListenPacketWithDialer(ctx, metadata, dialer, dialFn) if err != nil { return nil, err @@ -422,21 +433,22 @@ func (t *Client) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadat } func (t *Client) forceClose() { - t.clientImpl.forceClose(nil, ClientClosed) + t.clientImpl.forceClose(nil, common.ClientClosed) } -func NewClient(clientOption *ClientOption, udp bool) *Client { +func NewClient(clientOption *ClientOption, udp bool, dialerRef C.Dialer) *Client { ci := &clientImpl{ ClientOption: clientOption, udp: udp, + dialerRef: dialerRef, } c := &Client{ci} runtime.SetFinalizer(c, closeClient) - log.Debugln("New Tuic Client at %p", c) + log.Debugln("New TuicV4 Client at %p", c) return c } func closeClient(client *Client) { - log.Debugln("Close Tuic Client at %p", client) + log.Debugln("Close TuicV4 Client at %p", client) client.forceClose() } diff --git a/transport/tuic/conn.go b/transport/tuic/v4/packet.go similarity index 60% rename from transport/tuic/conn.go rename to transport/tuic/v4/packet.go index f226746d..edd872cc 100644 --- a/transport/tuic/conn.go +++ b/transport/tuic/v4/packet.go @@ -1,4 +1,4 @@ -package tuic +package v4 import ( "net" @@ -10,100 +10,8 @@ import ( N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" - "github.com/Dreamacro/clash/transport/tuic/congestion" ) -const ( - DefaultStreamReceiveWindow = 15728640 // 15 MB/s - DefaultConnectionReceiveWindow = 67108864 // 64 MB/s -) - -func SetCongestionController(quicConn quic.Connection, cc string) { - switch cc { - case "cubic": - quicConn.SetCongestionControl( - congestion.NewCubicSender( - congestion.DefaultClock{}, - congestion.GetInitialPacketSize(quicConn.RemoteAddr()), - false, - nil, - ), - ) - case "new_reno": - quicConn.SetCongestionControl( - congestion.NewCubicSender( - congestion.DefaultClock{}, - congestion.GetInitialPacketSize(quicConn.RemoteAddr()), - true, - nil, - ), - ) - case "bbr": - quicConn.SetCongestionControl( - congestion.NewBBRSender( - congestion.DefaultClock{}, - congestion.GetInitialPacketSize(quicConn.RemoteAddr()), - congestion.InitialCongestionWindow*congestion.InitialMaxDatagramSize, - congestion.DefaultBBRMaxCongestionWindow*congestion.InitialMaxDatagramSize, - ), - ) - } -} - -type quicStreamConn struct { - quic.Stream - lock sync.Mutex - lAddr net.Addr - rAddr net.Addr - - closeDeferFn func() - - closeOnce sync.Once - closeErr error -} - -func (q *quicStreamConn) Write(p []byte) (n int, err error) { - q.lock.Lock() - defer q.lock.Unlock() - return q.Stream.Write(p) -} - -func (q *quicStreamConn) Close() error { - q.closeOnce.Do(func() { - q.closeErr = q.close() - }) - return q.closeErr -} - -func (q *quicStreamConn) close() error { - if q.closeDeferFn != nil { - defer q.closeDeferFn() - } - - // https://github.com/cloudflare/cloudflared/commit/ed2bac026db46b239699ac5ce4fcf122d7cab2cd - // Make sure a possible writer does not block the lock forever. We need it, so we can close the writer - // side of the stream safely. - _ = q.Stream.SetWriteDeadline(time.Now()) - - // This lock is eventually acquired despite Write also acquiring it, because we set a deadline to writes. - q.lock.Lock() - defer q.lock.Unlock() - - // We have to clean up the receiving stream ourselves since the Close in the bottom does not handle that. - q.Stream.CancelRead(0) - return q.Stream.Close() -} - -func (q *quicStreamConn) LocalAddr() net.Addr { - return q.lAddr -} - -func (q *quicStreamConn) RemoteAddr() net.Addr { - return q.rAddr -} - -var _ net.Conn = (*quicStreamConn)(nil) - type quicStreamPacketConn struct { connId uint32 quicConn quic.Connection diff --git a/transport/tuic/protocol.go b/transport/tuic/v4/protocol.go similarity index 99% rename from transport/tuic/protocol.go rename to transport/tuic/v4/protocol.go index 472bb980..65f0f9d5 100644 --- a/transport/tuic/protocol.go +++ b/transport/tuic/v4/protocol.go @@ -1,4 +1,4 @@ -package tuic +package v4 import ( "encoding/binary" diff --git a/transport/tuic/server.go b/transport/tuic/v4/server.go similarity index 91% rename from transport/tuic/server.go rename to transport/tuic/v4/server.go index f8c4b20a..525ead17 100644 --- a/transport/tuic/server.go +++ b/transport/tuic/v4/server.go @@ -1,4 +1,4 @@ -package tuic +package v4 import ( "bufio" @@ -11,19 +11,21 @@ import ( "sync/atomic" "time" + "github.com/Dreamacro/clash/adapter/inbound" N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/utils" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/transport/socks5" + "github.com/Dreamacro/clash/transport/tuic/common" "github.com/gofrs/uuid/v5" "github.com/metacubex/quic-go" ) type ServerOption struct { - HandleTcpFn func(conn net.Conn, addr socks5.Addr) error - HandleUdpFn func(addr socks5.Addr, packet C.UDPPacket) error + HandleTcpFn func(conn net.Conn, addr socks5.Addr, additions ...inbound.Addition) error + HandleUdpFn func(addr socks5.Addr, packet C.UDPPacket, additions ...inbound.Addition) error TlsConfig *tls.Config QuicConfig *quic.Config @@ -55,7 +57,7 @@ func (s *Server) Serve() error { if err != nil { return err } - SetCongestionController(conn, s.CongestionController) + common.SetCongestionController(conn, s.CongestionController) h := &serverHandler{ Server: s, quicConn: conn, @@ -162,11 +164,12 @@ func (s *serverHandler) handleStream() (err error) { return err } go func() (err error) { - stream := &quicStreamConn{ - Stream: quicStream, - lAddr: s.quicConn.LocalAddr(), - rAddr: s.quicConn.RemoteAddr(), - } + stream := common.NewQuicStreamConn( + quicStream, + s.quicConn.LocalAddr(), + s.quicConn.RemoteAddr(), + nil, + ) conn := N.NewBufferedConn(stream) connect, err := ReadConnect(conn) if err != nil { @@ -224,18 +227,18 @@ func (s *serverHandler) handleUniStream() (err error) { if err != nil { return } - ok := false + authOk := false for _, tkn := range s.Tokens { if authenticate.TKN == tkn { - ok = true + authOk = true break } } s.authOnce.Do(func() { - if !ok { + if !authOk { _ = s.quicConn.CloseWithError(AuthenticationFailed, "AuthenticationFailed") } - s.authOk = ok + s.authOk = authOk close(s.authCh) }) case PacketType: diff --git a/transport/tuic/v5/client.go b/transport/tuic/v5/client.go new file mode 100644 index 00000000..9b878177 --- /dev/null +++ b/transport/tuic/v5/client.go @@ -0,0 +1,386 @@ +package v5 + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "errors" + "net" + "runtime" + "sync" + "sync/atomic" + "time" + + atomic2 "github.com/Dreamacro/clash/common/atomic" + N "github.com/Dreamacro/clash/common/net" + "github.com/Dreamacro/clash/common/pool" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/log" + "github.com/Dreamacro/clash/transport/tuic/common" + + "github.com/metacubex/quic-go" + "github.com/zhangyunhao116/fastrand" +) + +type ClientOption struct { + TlsConfig *tls.Config + QuicConfig *quic.Config + Uuid [16]byte + Password string + UdpRelayMode string + CongestionController string + ReduceRtt bool + MaxUdpRelayPacketSize int + MaxOpenStreams int64 +} + +type clientImpl struct { + *ClientOption + udp bool + + quicConn quic.Connection + connMutex sync.Mutex + + openStreams atomic.Int64 + closed atomic.Bool + + udpInputMap sync.Map + + // only ready for PoolClient + dialerRef C.Dialer + lastVisited atomic2.TypedValue[time.Time] +} + +func (t *clientImpl) OpenStreams() int64 { + return t.openStreams.Load() +} + +func (t *clientImpl) DialerRef() C.Dialer { + return t.dialerRef +} + +func (t *clientImpl) LastVisited() time.Time { + return t.lastVisited.Load() +} + +func (t *clientImpl) SetLastVisited(last time.Time) { + t.lastVisited.Store(last) +} + +func (t *clientImpl) getQuicConn(ctx context.Context, dialer C.Dialer, dialFn common.DialFunc) (quic.Connection, error) { + t.connMutex.Lock() + defer t.connMutex.Unlock() + if t.quicConn != nil { + return t.quicConn, nil + } + transport, addr, err := dialFn(ctx, dialer) + if err != nil { + return nil, err + } + var quicConn quic.Connection + if t.ReduceRtt { + quicConn, err = transport.DialEarly(ctx, addr, t.TlsConfig, t.QuicConfig) + } else { + quicConn, err = transport.Dial(ctx, addr, t.TlsConfig, t.QuicConfig) + } + if err != nil { + return nil, err + } + + common.SetCongestionController(quicConn, t.CongestionController) + + go func() { + _ = t.sendAuthentication(quicConn) + }() + + if t.udp { + go func() { + _ = t.parseUDP(quicConn) + }() + } + + t.quicConn = quicConn + t.openStreams.Store(0) + return quicConn, nil +} + +func (t *clientImpl) sendAuthentication(quicConn quic.Connection) (err error) { + defer func() { + t.deferQuicConn(quicConn, err) + }() + stream, err := quicConn.OpenUniStream() + if err != nil { + return err + } + buf := pool.GetBuffer() + defer pool.PutBuffer(buf) + token, err := GenToken(quicConn.ConnectionState(), t.Uuid, t.Password) + if err != nil { + return err + } + err = NewAuthenticate(t.Uuid, token).WriteTo(buf) + if err != nil { + return err + } + _, err = buf.WriteTo(stream) + if err != nil { + return err + } + err = stream.Close() + if err != nil { + return + } + return nil +} + +func (t *clientImpl) parseUDP(quicConn quic.Connection) (err error) { + defer func() { + t.deferQuicConn(quicConn, err) + }() + switch t.UdpRelayMode { + case "quic": + for { + var stream quic.ReceiveStream + stream, err = quicConn.AcceptUniStream(context.Background()) + if err != nil { + return err + } + go func() (err error) { + var assocId uint16 + defer func() { + t.deferQuicConn(quicConn, err) + if err != nil && assocId != 0 { + if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok { + if conn, ok := val.(net.Conn); ok { + _ = conn.Close() + } + } + } + stream.CancelRead(0) + }() + reader := bufio.NewReader(stream) + packet, err := ReadPacket(reader) + if err != nil { + return + } + assocId = packet.ASSOC_ID + if val, ok := t.udpInputMap.Load(assocId); ok { + if conn, ok := val.(net.Conn); ok { + writer := bufio.NewWriterSize(conn, packet.BytesLen()) + _ = packet.WriteTo(writer) + _ = writer.Flush() + } + } + return + }() + } + default: // native + for { + var message []byte + message, err = quicConn.ReceiveMessage() + if err != nil { + return err + } + go func() (err error) { + var assocId uint16 + defer func() { + t.deferQuicConn(quicConn, err) + if err != nil && assocId != 0 { + if val, ok := t.udpInputMap.LoadAndDelete(assocId); ok { + if conn, ok := val.(net.Conn); ok { + _ = conn.Close() + } + } + } + }() + buffer := bytes.NewBuffer(message) + packet, err := ReadPacket(buffer) + if err != nil { + return + } + assocId = packet.ASSOC_ID + if val, ok := t.udpInputMap.Load(assocId); ok { + if conn, ok := val.(net.Conn); ok { + _, _ = conn.Write(message) + } + } + return + }() + } + } +} + +func (t *clientImpl) deferQuicConn(quicConn quic.Connection, err error) { + var netError net.Error + if err != nil && errors.As(err, &netError) { + t.forceClose(quicConn, err) + } +} + +func (t *clientImpl) forceClose(quicConn quic.Connection, err error) { + t.connMutex.Lock() + defer t.connMutex.Unlock() + if quicConn == nil { + quicConn = t.quicConn + } + if quicConn != nil { + if quicConn == t.quicConn { + t.quicConn = nil + } + } + errStr := "" + if err != nil { + errStr = err.Error() + } + if quicConn != nil { + _ = quicConn.CloseWithError(ProtocolError, errStr) + } + udpInputMap := &t.udpInputMap + udpInputMap.Range(func(key, value any) bool { + if conn, ok := value.(net.Conn); ok { + _ = conn.Close() + } + udpInputMap.Delete(key) + return true + }) +} + +func (t *clientImpl) Close() { + t.closed.Store(true) + if t.openStreams.Load() == 0 { + t.forceClose(nil, common.ClientClosed) + } +} + +func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) { + quicConn, err := t.getQuicConn(ctx, dialer, dialFn) + if err != nil { + return nil, err + } + openStreams := t.openStreams.Add(1) + if openStreams >= t.MaxOpenStreams { + t.openStreams.Add(-1) + return nil, common.TooManyOpenStreams + } + stream, err := func() (stream net.Conn, err error) { + defer func() { + t.deferQuicConn(quicConn, err) + }() + buf := pool.GetBuffer() + defer pool.PutBuffer(buf) + err = NewConnect(NewAddress(metadata)).WriteTo(buf) + if err != nil { + return nil, err + } + quicStream, err := quicConn.OpenStream() + if err != nil { + return nil, err + } + stream = common.NewQuicStreamConn( + quicStream, + quicConn.LocalAddr(), + quicConn.RemoteAddr(), + func() { + time.AfterFunc(C.DefaultTCPTimeout, func() { + openStreams := t.openStreams.Add(-1) + if openStreams == 0 && t.closed.Load() { + t.forceClose(quicConn, common.ClientClosed) + } + }) + }, + ) + _, err = buf.WriteTo(stream) + if err != nil { + _ = stream.Close() + return nil, err + } + return stream, err + }() + if err != nil { + return nil, err + } + + return stream, nil +} + +func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) { + quicConn, err := t.getQuicConn(ctx, dialer, dialFn) + if err != nil { + return nil, err + } + openStreams := t.openStreams.Add(1) + if openStreams >= t.MaxOpenStreams { + t.openStreams.Add(-1) + return nil, common.TooManyOpenStreams + } + + pipe1, pipe2 := net.Pipe() + var connId uint16 + for { + connId = uint16(fastrand.Intn(0xFFFF)) + _, loaded := t.udpInputMap.LoadOrStore(connId, pipe1) + if !loaded { + break + } + } + pc := &quicStreamPacketConn{ + connId: connId, + quicConn: quicConn, + inputConn: N.NewBufferedConn(pipe2), + udpRelayMode: t.UdpRelayMode, + maxUdpRelayPacketSize: t.MaxUdpRelayPacketSize, + deferQuicConnFn: t.deferQuicConn, + closeDeferFn: func() { + t.udpInputMap.Delete(connId) + time.AfterFunc(C.DefaultUDPTimeout, func() { + openStreams := t.openStreams.Add(-1) + if openStreams == 0 && t.closed.Load() { + t.forceClose(quicConn, common.ClientClosed) + } + }) + }, + } + return pc, nil +} + +type Client struct { + *clientImpl // use an independent pointer to let Finalizer can work no matter somewhere handle an influence in clientImpl inner +} + +func (t *Client) DialContextWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.Conn, error) { + conn, err := t.clientImpl.DialContextWithDialer(ctx, metadata, dialer, dialFn) + if err != nil { + return nil, err + } + return N.NewRefConn(conn, t), err +} + +func (t *Client) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) { + pc, err := t.clientImpl.ListenPacketWithDialer(ctx, metadata, dialer, dialFn) + if err != nil { + return nil, err + } + return N.NewRefPacketConn(pc, t), nil +} + +func (t *Client) forceClose() { + t.clientImpl.forceClose(nil, common.ClientClosed) +} + +func NewClient(clientOption *ClientOption, udp bool, dialerRef C.Dialer) *Client { + ci := &clientImpl{ + ClientOption: clientOption, + udp: udp, + dialerRef: dialerRef, + } + c := &Client{ci} + runtime.SetFinalizer(c, closeClient) + log.Debugln("New TuicV5 Client at %p", c) + return c +} + +func closeClient(client *Client) { + log.Debugln("Close TuicV5 Client at %p", client) + client.forceClose() +} diff --git a/transport/tuic/v5/frag.go b/transport/tuic/v5/frag.go new file mode 100644 index 00000000..30b7b3f5 --- /dev/null +++ b/transport/tuic/v5/frag.go @@ -0,0 +1,80 @@ +package v5 + +import ( + "bytes" + + "github.com/metacubex/quic-go" +) + +func fragWriteNative(quicConn quic.Connection, packet Packet, buf *bytes.Buffer, fragSize int) (err error) { + fullPayload := packet.DATA + off := 0 + fragID := uint8(0) + fragCount := uint8((len(fullPayload) + fragSize - 1) / fragSize) // round up + packet.FRAG_TOTAL = fragCount + for off < len(fullPayload) { + payloadSize := len(fullPayload) - off + if payloadSize > fragSize { + payloadSize = fragSize + } + frag := packet + frag.FRAG_ID = fragID + frag.SIZE = uint16(payloadSize) + frag.DATA = fullPayload[off : off+payloadSize] + off += payloadSize + fragID++ + buf.Reset() + err = frag.WriteTo(buf) + if err != nil { + return + } + data := buf.Bytes() + err = quicConn.SendMessage(data) + if err != nil { + return + } + packet.ADDR.TYPE = AtypNone // avoid "fragment 2/2: address in non-first fragment" + } + return +} + +type deFragger struct { + pkgID uint16 + frags []*Packet + count uint8 +} + +func (d *deFragger) Feed(m Packet) *Packet { + if m.FRAG_TOTAL <= 1 { + return &m + } + if m.FRAG_ID >= m.FRAG_TOTAL { + // wtf is this? + return nil + } + if d.count == 0 || m.PKT_ID != d.pkgID { + // new message, clear previous state + d.pkgID = m.PKT_ID + d.frags = make([]*Packet, m.FRAG_TOTAL) + d.count = 1 + d.frags[m.FRAG_ID] = &m + } else if d.frags[m.FRAG_ID] == nil { + d.frags[m.FRAG_ID] = &m + d.count++ + if int(d.count) == len(d.frags) { + // all fragments received, assemble + var data []byte + for _, frag := range d.frags { + data = append(data, frag.DATA...) + } + p := d.frags[0] // recover from first fragment + p.SIZE = uint16(len(data)) + p.DATA = data + p.FRAG_ID = 0 + p.FRAG_TOTAL = 1 + d.count = 0 + return p + } + } + return nil +} diff --git a/transport/tuic/v5/packet.go b/transport/tuic/v5/packet.go new file mode 100644 index 00000000..50c602eb --- /dev/null +++ b/transport/tuic/v5/packet.go @@ -0,0 +1,208 @@ +package v5 + +import ( + "errors" + "net" + "sync" + "sync/atomic" + "time" + + N "github.com/Dreamacro/clash/common/net" + "github.com/Dreamacro/clash/common/pool" + + "github.com/metacubex/quic-go" + "github.com/zhangyunhao116/fastrand" +) + +type quicStreamPacketConn struct { + connId uint16 + quicConn quic.Connection + inputConn *N.BufferedConn + + udpRelayMode string + maxUdpRelayPacketSize int + + deferQuicConnFn func(quicConn quic.Connection, err error) + closeDeferFn func() + writeClosed *atomic.Bool + + closeOnce sync.Once + closeErr error + closed bool + + deFragger +} + +func (q *quicStreamPacketConn) Close() error { + q.closeOnce.Do(func() { + q.closed = true + q.closeErr = q.close() + }) + return q.closeErr +} + +func (q *quicStreamPacketConn) close() (err error) { + if q.closeDeferFn != nil { + defer q.closeDeferFn() + } + if q.deferQuicConnFn != nil { + defer func() { + q.deferQuicConnFn(q.quicConn, err) + }() + } + if q.inputConn != nil { + _ = q.inputConn.Close() + q.inputConn = nil + + buf := pool.GetBuffer() + defer pool.PutBuffer(buf) + err = NewDissociate(q.connId).WriteTo(buf) + if err != nil { + return + } + var stream quic.SendStream + stream, err = q.quicConn.OpenUniStream() + if err != nil { + return + } + _, err = buf.WriteTo(stream) + if err != nil { + return + } + err = stream.Close() + if err != nil { + return + } + } + return +} + +func (q *quicStreamPacketConn) SetDeadline(t time.Time) error { + //TODO implement me + return nil +} + +func (q *quicStreamPacketConn) SetReadDeadline(t time.Time) error { + if q.inputConn != nil { + return q.inputConn.SetReadDeadline(t) + } + return nil +} + +func (q *quicStreamPacketConn) SetWriteDeadline(t time.Time) error { + //TODO implement me + return nil +} + +func (q *quicStreamPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if q.inputConn != nil { + for { + var packet Packet + packet, err = ReadPacket(q.inputConn) + if err != nil { + return + } + if packetPtr := q.deFragger.Feed(packet); packetPtr != nil { + n = copy(p, packet.DATA) + addr = packetPtr.ADDR.UDPAddr() + return + } + } + } else { + err = net.ErrClosed + } + return +} + +func (q *quicStreamPacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { + if q.inputConn != nil { + for { + var packet Packet + packet, err = ReadPacket(q.inputConn) + if err != nil { + return + } + if packetPtr := q.deFragger.Feed(packet); packetPtr != nil { + data = packetPtr.DATA + addr = packetPtr.ADDR.UDPAddr() + return + } + } + } else { + err = net.ErrClosed + } + return +} + +func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if len(p) > 0xffff { // uint16 max + return 0, quic.ErrMessageTooLarge(0xffff) + } + if q.closed { + return 0, net.ErrClosed + } + if q.writeClosed != nil && q.writeClosed.Load() { + _ = q.Close() + return 0, net.ErrClosed + } + if q.deferQuicConnFn != nil { + defer func() { + q.deferQuicConnFn(q.quicConn, err) + }() + } + buf := pool.GetBuffer() + defer pool.PutBuffer(buf) + address, err := NewAddressNetAddr(addr) + if err != nil { + return + } + pktId := uint16(fastrand.Uint32()) + packet := NewPacket(q.connId, pktId, 1, 0, uint16(len(p)), address, p) + switch q.udpRelayMode { + case "quic": + err = packet.WriteTo(buf) + if err != nil { + return + } + var stream quic.SendStream + stream, err = q.quicConn.OpenUniStream() + if err != nil { + return + } + defer stream.Close() + _, err = buf.WriteTo(stream) + if err != nil { + return + } + default: // native + if len(p) > q.maxUdpRelayPacketSize { + err = fragWriteNative(q.quicConn, packet, buf, q.maxUdpRelayPacketSize) + if err != nil { + return + } + } + err = packet.WriteTo(buf) + if err != nil { + return + } + data := buf.Bytes() + err = q.quicConn.SendMessage(data) + + var tooLarge quic.ErrMessageTooLarge + if errors.As(err, &tooLarge) { + err = fragWriteNative(q.quicConn, packet, buf, int(tooLarge)-PacketOverHead) + } + if err != nil { + return + } + } + n = len(p) + + return +} + +func (q *quicStreamPacketConn) LocalAddr() net.Addr { + return q.quicConn.LocalAddr() +} + +var _ net.PacketConn = (*quicStreamPacketConn)(nil) diff --git a/transport/tuic/v5/protocol.go b/transport/tuic/v5/protocol.go new file mode 100644 index 00000000..f2849746 --- /dev/null +++ b/transport/tuic/v5/protocol.go @@ -0,0 +1,651 @@ +package v5 + +import ( + "encoding/binary" + "fmt" + "io" + "net" + "net/netip" + "strconv" + + "github.com/Dreamacro/clash/common/utils" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/transport/socks5" + + "github.com/metacubex/quic-go" +) + +type BufferedReader interface { + io.Reader + io.ByteReader +} + +type BufferedWriter interface { + io.Writer + io.ByteWriter +} + +type CommandType byte + +const ( + AuthenticateType = CommandType(0x00) + ConnectType = CommandType(0x01) + PacketType = CommandType(0x02) + DissociateType = CommandType(0x03) + HeartbeatType = CommandType(0x04) + ResponseType = CommandType(0xff) +) + +func (c CommandType) String() string { + switch c { + case AuthenticateType: + return "Authenticate" + case ConnectType: + return "Connect" + case PacketType: + return "Packet" + case DissociateType: + return "Dissociate" + case HeartbeatType: + return "Heartbeat" + case ResponseType: + return "Response" + default: + return fmt.Sprintf("UnknowCommand: %#x", byte(c)) + } +} + +func (c CommandType) BytesLen() int { + return 1 +} + +type CommandHead struct { + VER byte + TYPE CommandType +} + +func NewCommandHead(TYPE CommandType) CommandHead { + return CommandHead{ + VER: 0x05, + TYPE: TYPE, + } +} + +func ReadCommandHead(reader BufferedReader) (c CommandHead, err error) { + c.VER, err = reader.ReadByte() + if err != nil { + return + } + TYPE, err := reader.ReadByte() + if err != nil { + return + } + c.TYPE = CommandType(TYPE) + return +} + +func (c CommandHead) WriteTo(writer BufferedWriter) (err error) { + err = writer.WriteByte(c.VER) + if err != nil { + return + } + err = writer.WriteByte(byte(c.TYPE)) + if err != nil { + return + } + return +} + +func (c CommandHead) BytesLen() int { + return 1 + c.TYPE.BytesLen() +} + +type Authenticate struct { + CommandHead + UUID [16]byte + TOKEN [32]byte +} + +func NewAuthenticate(UUID [16]byte, TOKEN [32]byte) Authenticate { + return Authenticate{ + CommandHead: NewCommandHead(AuthenticateType), + UUID: UUID, + TOKEN: TOKEN, + } +} + +func ReadAuthenticateWithHead(head CommandHead, reader BufferedReader) (c Authenticate, err error) { + c.CommandHead = head + if c.CommandHead.TYPE != AuthenticateType { + err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) + return + } + _, err = io.ReadFull(reader, c.UUID[:]) + if err != nil { + return + } + _, err = io.ReadFull(reader, c.TOKEN[:]) + if err != nil { + return + } + return +} + +func ReadAuthenticate(reader BufferedReader) (c Authenticate, err error) { + head, err := ReadCommandHead(reader) + if err != nil { + return + } + return ReadAuthenticateWithHead(head, reader) +} + +func GenToken(state quic.ConnectionState, uuid [16]byte, password string) (token [32]byte, err error) { + var tokenBytes []byte + tokenBytes, err = state.TLS.ExportKeyingMaterial(utils.StringFromImmutableBytes(uuid[:]), utils.ImmutableBytesFromString(password), 32) + if err != nil { + return + } + copy(token[:], tokenBytes) + return +} + +func (c Authenticate) WriteTo(writer BufferedWriter) (err error) { + err = c.CommandHead.WriteTo(writer) + if err != nil { + return + } + _, err = writer.Write(c.UUID[:]) + if err != nil { + return + } + _, err = writer.Write(c.TOKEN[:]) + if err != nil { + return + } + return +} + +func (c Authenticate) BytesLen() int { + return c.CommandHead.BytesLen() + 16 + 32 +} + +type Connect struct { + CommandHead + ADDR Address +} + +func NewConnect(ADDR Address) Connect { + return Connect{ + CommandHead: NewCommandHead(ConnectType), + ADDR: ADDR, + } +} + +func ReadConnectWithHead(head CommandHead, reader BufferedReader) (c Connect, err error) { + c.CommandHead = head + if c.CommandHead.TYPE != ConnectType { + err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) + return + } + c.ADDR, err = ReadAddress(reader) + if err != nil { + return + } + return +} + +func ReadConnect(reader BufferedReader) (c Connect, err error) { + head, err := ReadCommandHead(reader) + if err != nil { + return + } + return ReadConnectWithHead(head, reader) +} + +func (c Connect) WriteTo(writer BufferedWriter) (err error) { + err = c.CommandHead.WriteTo(writer) + if err != nil { + return + } + err = c.ADDR.WriteTo(writer) + if err != nil { + return + } + return +} + +func (c Connect) BytesLen() int { + return c.CommandHead.BytesLen() + c.ADDR.BytesLen() +} + +type Packet struct { + CommandHead + ASSOC_ID uint16 + PKT_ID uint16 + FRAG_TOTAL uint8 + FRAG_ID uint8 + SIZE uint16 + ADDR Address + DATA []byte +} + +func NewPacket(ASSOC_ID uint16, PKT_ID uint16, FRGA_TOTAL uint8, FRAG_ID uint8, SIZE uint16, ADDR Address, DATA []byte) Packet { + return Packet{ + CommandHead: NewCommandHead(PacketType), + ASSOC_ID: ASSOC_ID, + PKT_ID: PKT_ID, + FRAG_ID: FRAG_ID, + FRAG_TOTAL: FRGA_TOTAL, + SIZE: SIZE, + ADDR: ADDR, + DATA: DATA, + } +} + +func ReadPacketWithHead(head CommandHead, reader BufferedReader) (c Packet, err error) { + c.CommandHead = head + if c.CommandHead.TYPE != PacketType { + err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) + return + } + err = binary.Read(reader, binary.BigEndian, &c.ASSOC_ID) + if err != nil { + return + } + err = binary.Read(reader, binary.BigEndian, &c.PKT_ID) + if err != nil { + return + } + err = binary.Read(reader, binary.BigEndian, &c.FRAG_TOTAL) + if err != nil { + return + } + err = binary.Read(reader, binary.BigEndian, &c.FRAG_ID) + if err != nil { + return + } + err = binary.Read(reader, binary.BigEndian, &c.SIZE) + if err != nil { + return + } + c.ADDR, err = ReadAddress(reader) + if err != nil { + return + } + c.DATA = make([]byte, c.SIZE) + _, err = io.ReadFull(reader, c.DATA) + if err != nil { + return + } + return +} + +func ReadPacket(reader BufferedReader) (c Packet, err error) { + head, err := ReadCommandHead(reader) + if err != nil { + return + } + return ReadPacketWithHead(head, reader) +} + +func (c Packet) WriteTo(writer BufferedWriter) (err error) { + err = c.CommandHead.WriteTo(writer) + if err != nil { + return + } + err = binary.Write(writer, binary.BigEndian, c.ASSOC_ID) + if err != nil { + return + } + err = binary.Write(writer, binary.BigEndian, c.PKT_ID) + if err != nil { + return + } + err = binary.Write(writer, binary.BigEndian, c.FRAG_TOTAL) + if err != nil { + return + } + err = binary.Write(writer, binary.BigEndian, c.FRAG_ID) + if err != nil { + return + } + err = binary.Write(writer, binary.BigEndian, c.SIZE) + if err != nil { + return + } + err = c.ADDR.WriteTo(writer) + if err != nil { + return + } + _, err = writer.Write(c.DATA) + if err != nil { + return + } + return +} + +func (c Packet) BytesLen() int { + return c.CommandHead.BytesLen() + 4 + 2 + c.ADDR.BytesLen() + len(c.DATA) +} + +var PacketOverHead = NewPacket(0, 0, 0, 0, 0, NewAddressAddrPort(netip.AddrPortFrom(netip.IPv6Unspecified(), 0)), nil).BytesLen() + +type Dissociate struct { + CommandHead + ASSOC_ID uint16 +} + +func NewDissociate(ASSOC_ID uint16) Dissociate { + return Dissociate{ + CommandHead: NewCommandHead(DissociateType), + ASSOC_ID: ASSOC_ID, + } +} + +func ReadDissociateWithHead(head CommandHead, reader BufferedReader) (c Dissociate, err error) { + c.CommandHead = head + if c.CommandHead.TYPE != DissociateType { + err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) + return + } + err = binary.Read(reader, binary.BigEndian, &c.ASSOC_ID) + if err != nil { + return + } + return +} + +func ReadDissociate(reader BufferedReader) (c Dissociate, err error) { + head, err := ReadCommandHead(reader) + if err != nil { + return + } + return ReadDissociateWithHead(head, reader) +} + +func (c Dissociate) WriteTo(writer BufferedWriter) (err error) { + err = c.CommandHead.WriteTo(writer) + if err != nil { + return + } + err = binary.Write(writer, binary.BigEndian, c.ASSOC_ID) + if err != nil { + return + } + return +} + +func (c Dissociate) BytesLen() int { + return c.CommandHead.BytesLen() + 4 +} + +type Heartbeat struct { + CommandHead +} + +func NewHeartbeat() Heartbeat { + return Heartbeat{ + CommandHead: NewCommandHead(HeartbeatType), + } +} + +func ReadHeartbeatWithHead(head CommandHead, reader BufferedReader) (c Heartbeat, err error) { + c.CommandHead = head + if c.CommandHead.TYPE != HeartbeatType { + err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) + return + } + return +} + +func ReadHeartbeat(reader BufferedReader) (c Heartbeat, err error) { + head, err := ReadCommandHead(reader) + if err != nil { + return + } + return ReadHeartbeatWithHead(head, reader) +} + +type Response struct { + CommandHead + REP byte +} + +func NewResponse(REP byte) Response { + return Response{ + CommandHead: NewCommandHead(ResponseType), + REP: REP, + } +} + +func NewResponseSucceed() Response { + return NewResponse(0x00) +} + +func NewResponseFailed() Response { + return NewResponse(0xff) +} + +func ReadResponseWithHead(head CommandHead, reader BufferedReader) (c Response, err error) { + c.CommandHead = head + if c.CommandHead.TYPE != ResponseType { + err = fmt.Errorf("error command type: %s", c.CommandHead.TYPE) + return + } + c.REP, err = reader.ReadByte() + if err != nil { + return + } + return +} + +func ReadResponse(reader BufferedReader) (c Response, err error) { + head, err := ReadCommandHead(reader) + if err != nil { + return + } + return ReadResponseWithHead(head, reader) +} + +func (c Response) WriteTo(writer BufferedWriter) (err error) { + err = c.CommandHead.WriteTo(writer) + if err != nil { + return + } + err = writer.WriteByte(c.REP) + if err != nil { + return + } + return +} + +func (c Response) IsSucceed() bool { + return c.REP == 0x00 +} + +func (c Response) IsFailed() bool { + return c.REP == 0xff +} + +func (c Response) BytesLen() int { + return c.CommandHead.BytesLen() + 1 +} + +// Addr types +const ( + AtypDomainName byte = 0 + AtypIPv4 byte = 1 + AtypIPv6 byte = 2 + AtypNone byte = 255 // Address type None is used in Packet commands that is not the first fragment of a UDP packet. +) + +type Address struct { + TYPE byte + ADDR []byte + PORT uint16 +} + +func NewAddress(metadata *C.Metadata) Address { + var addrType byte + var addr []byte + switch metadata.AddrType() { + case socks5.AtypIPv4: + addrType = AtypIPv4 + addr = metadata.DstIP.AsSlice() + case socks5.AtypIPv6: + addrType = AtypIPv6 + addr = metadata.DstIP.AsSlice() + case socks5.AtypDomainName: + addrType = AtypDomainName + addr = make([]byte, len(metadata.Host)+1) + addr[0] = byte(len(metadata.Host)) + copy(addr[1:], metadata.Host) + } + + port, _ := strconv.ParseUint(metadata.DstPort, 10, 16) + + return Address{ + TYPE: addrType, + ADDR: addr, + PORT: uint16(port), + } +} + +func NewAddressNetAddr(addr net.Addr) (Address, error) { + if addr, ok := addr.(interface{ AddrPort() netip.AddrPort }); ok { + if addrPort := addr.AddrPort(); addrPort.IsValid() { // sing's M.Socksaddr maybe return an invalid AddrPort if it's a DomainName + return NewAddressAddrPort(addrPort), nil + } + } + addrStr := addr.String() + if addrPort, err := netip.ParseAddrPort(addrStr); err == nil { + return NewAddressAddrPort(addrPort), nil + } + metadata := &C.Metadata{} + if err := metadata.SetRemoteAddress(addrStr); err != nil { + return Address{}, err + } + return NewAddress(metadata), nil +} + +func NewAddressAddrPort(addrPort netip.AddrPort) Address { + var addrType byte + port := addrPort.Port() + addr := addrPort.Addr().Unmap() + if addr.Is4() { + addrType = AtypIPv4 + } else { + addrType = AtypIPv6 + } + return Address{ + TYPE: addrType, + ADDR: addr.AsSlice(), + PORT: port, + } +} + +func ReadAddress(reader BufferedReader) (c Address, err error) { + c.TYPE, err = reader.ReadByte() + if err != nil { + return + } + switch c.TYPE { + case AtypIPv4: + c.ADDR = make([]byte, net.IPv4len) + _, err = io.ReadFull(reader, c.ADDR) + if err != nil { + return + } + case AtypIPv6: + c.ADDR = make([]byte, net.IPv6len) + _, err = io.ReadFull(reader, c.ADDR) + if err != nil { + return + } + case AtypDomainName: + var addrLen byte + addrLen, err = reader.ReadByte() + if err != nil { + return + } + c.ADDR = make([]byte, addrLen+1) + c.ADDR[0] = addrLen + _, err = io.ReadFull(reader, c.ADDR[1:]) + if err != nil { + return + } + } + + if c.TYPE == AtypNone { + return + } + err = binary.Read(reader, binary.BigEndian, &c.PORT) + if err != nil { + return + } + return +} + +func (c Address) WriteTo(writer BufferedWriter) (err error) { + err = writer.WriteByte(c.TYPE) + if err != nil { + return + } + if c.TYPE == AtypNone { + return + } + _, err = writer.Write(c.ADDR[:]) + if err != nil { + return + } + err = binary.Write(writer, binary.BigEndian, c.PORT) + if err != nil { + return + } + return +} + +func (c Address) String() string { + switch c.TYPE { + case AtypDomainName: + return net.JoinHostPort(string(c.ADDR[1:]), strconv.Itoa(int(c.PORT))) + default: + addr, _ := netip.AddrFromSlice(c.ADDR) + addrPort := netip.AddrPortFrom(addr, c.PORT) + return addrPort.String() + } +} + +func (c Address) SocksAddr() socks5.Addr { + addr := make([]byte, 1+len(c.ADDR)+2) + switch c.TYPE { + case AtypIPv4: + addr[0] = socks5.AtypIPv4 + case AtypIPv6: + addr[0] = socks5.AtypIPv6 + case AtypDomainName: + addr[0] = socks5.AtypDomainName + } + copy(addr[1:], c.ADDR) + binary.BigEndian.PutUint16(addr[len(addr)-2:], c.PORT) + return addr +} + +func (c Address) UDPAddr() *net.UDPAddr { + return &net.UDPAddr{ + IP: c.ADDR, + Port: int(c.PORT), + Zone: "", + } +} + +func (c Address) BytesLen() int { + return 1 + len(c.ADDR) + 2 +} + +const ( + ProtocolError = quic.ApplicationErrorCode(0xfffffff0) + AuthenticationFailed = quic.ApplicationErrorCode(0xfffffff1) + AuthenticationTimeout = quic.ApplicationErrorCode(0xfffffff2) + BadCommand = quic.ApplicationErrorCode(0xfffffff3) +) diff --git a/transport/tuic/v5/server.go b/transport/tuic/v5/server.go new file mode 100644 index 00000000..3e3dc52f --- /dev/null +++ b/transport/tuic/v5/server.go @@ -0,0 +1,303 @@ +package v5 + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/Dreamacro/clash/adapter/inbound" + N "github.com/Dreamacro/clash/common/net" + "github.com/Dreamacro/clash/common/utils" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/transport/socks5" + "github.com/Dreamacro/clash/transport/tuic/common" + + "github.com/gofrs/uuid/v5" + "github.com/metacubex/quic-go" +) + +type ServerOption struct { + HandleTcpFn func(conn net.Conn, addr socks5.Addr, additions ...inbound.Addition) error + HandleUdpFn func(addr socks5.Addr, packet C.UDPPacket, additions ...inbound.Addition) error + + TlsConfig *tls.Config + QuicConfig *quic.Config + Users map[[16]byte]string + CongestionController string + AuthenticationTimeout time.Duration + MaxUdpRelayPacketSize int +} + +type Server struct { + *ServerOption + listener *quic.EarlyListener +} + +func NewServer(option *ServerOption, pc net.PacketConn) (*Server, error) { + listener, err := quic.ListenEarly(pc, option.TlsConfig, option.QuicConfig) + if err != nil { + return nil, err + } + return &Server{ + ServerOption: option, + listener: listener, + }, err +} + +func (s *Server) Serve() error { + for { + conn, err := s.listener.Accept(context.Background()) + if err != nil { + return err + } + common.SetCongestionController(conn, s.CongestionController) + h := &serverHandler{ + Server: s, + quicConn: conn, + uuid: utils.NewUUIDV4(), + authCh: make(chan struct{}), + } + go h.handle() + } +} + +func (s *Server) Close() error { + return s.listener.Close() +} + +type serverHandler struct { + *Server + quicConn quic.EarlyConnection + uuid uuid.UUID + + authCh chan struct{} + authOk bool + authUUID string + authOnce sync.Once + + udpInputMap sync.Map +} + +func (s *serverHandler) handle() { + go func() { + _ = s.handleUniStream() + }() + go func() { + _ = s.handleStream() + }() + go func() { + _ = s.handleMessage() + }() + + <-s.quicConn.HandshakeComplete() + time.AfterFunc(s.AuthenticationTimeout, func() { + s.authOnce.Do(func() { + _ = s.quicConn.CloseWithError(AuthenticationTimeout, "AuthenticationTimeout") + s.authOk = false + close(s.authCh) + }) + }) +} + +func (s *serverHandler) handleMessage() (err error) { + for { + var message []byte + message, err = s.quicConn.ReceiveMessage() + if err != nil { + return err + } + go func() (err error) { + buffer := bytes.NewBuffer(message) + packet, err := ReadPacket(buffer) + if err != nil { + return + } + return s.parsePacket(packet, "native") + }() + } +} + +func (s *serverHandler) parsePacket(packet Packet, udpRelayMode string) (err error) { + <-s.authCh + if !s.authOk { + return + } + var assocId uint16 + + assocId = packet.ASSOC_ID + + v, _ := s.udpInputMap.LoadOrStore(assocId, &serverUDPInput{}) + input := v.(*serverUDPInput) + if input.writeClosed.Load() { + return nil + } + packetPtr := input.Feed(packet) + if packetPtr == nil { + return + } + + pc := &quicStreamPacketConn{ + connId: assocId, + quicConn: s.quicConn, + inputConn: nil, + udpRelayMode: udpRelayMode, + maxUdpRelayPacketSize: s.MaxUdpRelayPacketSize, + deferQuicConnFn: nil, + closeDeferFn: nil, + writeClosed: &input.writeClosed, + } + + return s.HandleUdpFn(packetPtr.ADDR.SocksAddr(), &serverUDPPacket{ + pc: pc, + packet: packetPtr, + rAddr: N.NewCustomAddr("tuic", fmt.Sprintf("tuic-%s-%d", s.uuid, assocId), s.quicConn.RemoteAddr()), // for tunnel's handleUDPConn + }, inbound.WithInUser(s.authUUID)) +} + +func (s *serverHandler) handleStream() (err error) { + for { + var quicStream quic.Stream + quicStream, err = s.quicConn.AcceptStream(context.Background()) + if err != nil { + return err + } + go func() (err error) { + stream := common.NewQuicStreamConn( + quicStream, + s.quicConn.LocalAddr(), + s.quicConn.RemoteAddr(), + nil, + ) + conn := N.NewBufferedConn(stream) + connect, err := ReadConnect(conn) + if err != nil { + return err + } + <-s.authCh + if !s.authOk { + return conn.Close() + } + + err = s.HandleTcpFn(conn, connect.ADDR.SocksAddr(), inbound.WithInUser(s.authUUID)) + if err != nil { + _ = conn.Close() + return err + } + return + }() + } +} + +func (s *serverHandler) handleUniStream() (err error) { + for { + var stream quic.ReceiveStream + stream, err = s.quicConn.AcceptUniStream(context.Background()) + if err != nil { + return err + } + go func() (err error) { + defer func() { + stream.CancelRead(0) + }() + reader := bufio.NewReader(stream) + commandHead, err := ReadCommandHead(reader) + if err != nil { + return + } + switch commandHead.TYPE { + case AuthenticateType: + var authenticate Authenticate + authenticate, err = ReadAuthenticateWithHead(commandHead, reader) + if err != nil { + return + } + authOk := false + var authUUID uuid.UUID + var token [32]byte + if password, ok := s.Users[authenticate.UUID]; ok { + token, err = GenToken(s.quicConn.ConnectionState(), authenticate.UUID, password) + if err != nil { + return + } + if token == authenticate.TOKEN { + authOk = true + authUUID = authenticate.UUID + } + } + s.authOnce.Do(func() { + if !authOk { + _ = s.quicConn.CloseWithError(AuthenticationFailed, "AuthenticationFailed") + } + s.authOk = authOk + s.authUUID = authUUID.String() + close(s.authCh) + }) + case PacketType: + var packet Packet + packet, err = ReadPacketWithHead(commandHead, reader) + if err != nil { + return + } + return s.parsePacket(packet, "quic") + case DissociateType: + var disassociate Dissociate + disassociate, err = ReadDissociateWithHead(commandHead, reader) + if err != nil { + return + } + if v, loaded := s.udpInputMap.LoadAndDelete(disassociate.ASSOC_ID); loaded { + input := v.(*serverUDPInput) + input.writeClosed.Store(true) + } + case HeartbeatType: + var heartbeat Heartbeat + heartbeat, err = ReadHeartbeatWithHead(commandHead, reader) + if err != nil { + return + } + heartbeat.BytesLen() + } + return + }() + } +} + +type serverUDPInput struct { + writeClosed atomic.Bool + deFragger +} + +type serverUDPPacket struct { + pc *quicStreamPacketConn + packet *Packet + rAddr net.Addr +} + +func (s *serverUDPPacket) InAddr() net.Addr { + return s.pc.LocalAddr() +} + +func (s *serverUDPPacket) LocalAddr() net.Addr { + return s.rAddr +} + +func (s *serverUDPPacket) Data() []byte { + return s.packet.DATA +} + +func (s *serverUDPPacket) WriteBack(b []byte, addr net.Addr) (n int, err error) { + return s.pc.WriteTo(b, addr) +} + +func (s *serverUDPPacket) Drop() { + s.packet.DATA = nil +} + +var _ C.UDPPacket = (*serverUDPPacket)(nil) +var _ C.UDPPacketInAddr = (*serverUDPPacket)(nil)