chore: better tuic conn close

This commit is contained in:
gVisor bot 2022-11-25 11:32:05 +08:00
parent b2939ad863
commit 25540e6c96
2 changed files with 26 additions and 11 deletions

View File

@ -9,6 +9,7 @@ import (
"fmt" "fmt"
"net" "net"
"os" "os"
"runtime"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@ -199,6 +200,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
RequestTimeout: option.RequestTimeout, RequestTimeout: option.RequestTimeout,
} }
clientMap[o] = client clientMap[o] = client
runtime.SetFinalizer(client, closeTuicClient)
return client return client
} }
@ -214,3 +216,7 @@ func NewTuic(option TuicOption) (*Tuic, error) {
getClient: getClient, getClient: getClient,
}, nil }, nil
} }
func closeTuicClient(client *tuic.Client) {
client.Close(nil)
}

View File

@ -197,6 +197,15 @@ func (t *Client) deferQuicConn(quicConn quic.Connection, err error) {
t.connMutex.Lock() t.connMutex.Lock()
defer t.connMutex.Unlock() defer t.connMutex.Unlock()
if t.quicConn == quicConn { if t.quicConn == quicConn {
t.Close(err)
}
}
}
func (t *Client) Close(err error) {
quicConn := t.quicConn
if quicConn != nil {
_ = t.quicConn.CloseWithError(ProtocolError, err.Error())
t.udpInputMap.Range(func(key, value any) bool { t.udpInputMap.Range(func(key, value any) bool {
if conn, ok := value.(net.Conn); ok { if conn, ok := value.(net.Conn); ok {
_ = conn.Close() _ = conn.Close()
@ -207,7 +216,6 @@ func (t *Client) deferQuicConn(quicConn quic.Connection, err error) {
t.quicConn = nil t.quicConn = nil
} }
} }
}
func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)) (net.Conn, error) { func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn func(ctx context.Context) (net.PacketConn, net.Addr, error)) (net.Conn, error) {
quicConn, err := t.getQuicConn(ctx, dialFn) quicConn, err := t.getQuicConn(ctx, dialFn)
@ -237,7 +245,7 @@ func (t *Client) DialContext(ctx context.Context, metadata *C.Metadata, dialFn f
if t.RequestTimeout > 0 { if t.RequestTimeout > 0 {
_ = stream.SetReadDeadline(time.Now().Add(time.Duration(t.RequestTimeout) * time.Millisecond)) _ = stream.SetReadDeadline(time.Now().Add(time.Duration(t.RequestTimeout) * time.Millisecond))
} }
conn := N.NewBufferedConn(&quicStreamConn{stream, quicConn.LocalAddr(), quicConn.RemoteAddr()}) conn := N.NewBufferedConn(&quicStreamConn{stream, quicConn.LocalAddr(), quicConn.RemoteAddr(), t})
response, err := ReadResponse(conn) response, err := ReadResponse(conn)
if err != nil { if err != nil {
return nil, err return nil, err
@ -254,6 +262,7 @@ type quicStreamConn struct {
quic.Stream quic.Stream
lAddr net.Addr lAddr net.Addr
rAddr net.Addr rAddr net.Addr
client *Client
} }
func (q *quicStreamConn) LocalAddr() net.Addr { func (q *quicStreamConn) LocalAddr() net.Addr {