From a56229a365cfe97d9611230b3549decf2beb6429 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Thu, 30 Nov 2023 20:00:24 +0800 Subject: [PATCH] chore: simplify fast open code --- common/net/earlyconn.go | 67 +++++++++++++++++++++++++++ transport/tuic/v4/client.go | 85 ++++++++-------------------------- transport/vmess/httpupgrade.go | 65 -------------------------- transport/vmess/websocket.go | 15 ++++-- 4 files changed, 98 insertions(+), 134 deletions(-) create mode 100644 common/net/earlyconn.go delete mode 100644 transport/vmess/httpupgrade.go diff --git a/common/net/earlyconn.go b/common/net/earlyconn.go new file mode 100644 index 00000000..c9a42819 --- /dev/null +++ b/common/net/earlyconn.go @@ -0,0 +1,67 @@ +package net + +import ( + "net" + "sync" + "sync/atomic" + "unsafe" + + "github.com/metacubex/mihomo/common/buf" +) + +type earlyConn struct { + ExtendedConn // only expose standard N.ExtendedConn function to outside + resFunc func() error + resOnce sync.Once + resErr error +} + +func (conn *earlyConn) Response() error { + conn.resOnce.Do(func() { + conn.resErr = conn.resFunc() + }) + return conn.resErr +} + +func (conn *earlyConn) Read(b []byte) (n int, err error) { + err = conn.Response() + if err != nil { + return 0, err + } + return conn.ExtendedConn.Read(b) +} + +func (conn *earlyConn) ReadBuffer(buffer *buf.Buffer) (err error) { + err = conn.Response() + if err != nil { + return err + } + return conn.ExtendedConn.ReadBuffer(buffer) +} + +func (conn *earlyConn) Upstream() any { + return conn.ExtendedConn +} + +func (conn *earlyConn) Success() bool { + // atomic visit sync.Once.done + return atomic.LoadUint32((*uint32)(unsafe.Pointer(&conn.resOnce))) == 1 && conn.resErr == nil +} + +func (conn *earlyConn) ReaderReplaceable() bool { + return conn.Success() +} + +func (conn *earlyConn) ReaderPossiblyReplaceable() bool { + return !conn.Success() +} + +func (conn *earlyConn) WriterReplaceable() bool { + return true +} + +var _ ExtendedConn = (*earlyConn)(nil) + +func NewEarlyConn(c net.Conn, f func() error) net.Conn { + return &earlyConn{ExtendedConn: NewExtendedConn(c), resFunc: f} +} diff --git a/transport/tuic/v4/client.go b/transport/tuic/v4/client.go index a553db82..a8474376 100644 --- a/transport/tuic/v4/client.go +++ b/transport/tuic/v4/client.go @@ -11,10 +11,8 @@ import ( "sync" "sync/atomic" "time" - "unsafe" atomic2 "github.com/metacubex/mihomo/common/atomic" - "github.com/metacubex/mihomo/common/buf" N "github.com/metacubex/mihomo/common/net" "github.com/metacubex/mihomo/common/pool" C "github.com/metacubex/mihomo/constant" @@ -329,75 +327,30 @@ func (t *clientImpl) DialContextWithDialer(ctx context.Context, metadata *C.Meta } bufConn := N.NewBufferedConn(stream) - conn := &earlyConn{ExtendedConn: bufConn, bufConn: bufConn, RequestTimeout: t.RequestTimeout} - if !t.FastOpen { - err = conn.Response() - if err != nil { - return nil, err + response := func() error { + if t.RequestTimeout > 0 { + _ = bufConn.SetReadDeadline(time.Now().Add(t.RequestTimeout)) } + response, err := ReadResponse(bufConn) + if err != nil { + _ = bufConn.Close() + return err + } + if response.IsFailed() { + _ = bufConn.Close() + return errors.New("connect failed") + } + _ = bufConn.SetReadDeadline(time.Time{}) + return nil } - return conn, nil -} - -type earlyConn struct { - N.ExtendedConn // only expose standard N.ExtendedConn function to outside - bufConn *N.BufferedConn - resOnce sync.Once - resErr error - - RequestTimeout time.Duration -} - -func (conn *earlyConn) response() error { - if conn.RequestTimeout > 0 { - _ = conn.SetReadDeadline(time.Now().Add(conn.RequestTimeout)) + if t.FastOpen { + return N.NewEarlyConn(bufConn, response), nil } - response, err := ReadResponse(conn.bufConn) + err = response() if err != nil { - _ = conn.Close() - return err + return nil, err } - if response.IsFailed() { - _ = conn.Close() - return errors.New("connect failed") - } - _ = conn.SetReadDeadline(time.Time{}) - return nil -} - -func (conn *earlyConn) Response() error { - conn.resOnce.Do(func() { - conn.resErr = conn.response() - }) - return conn.resErr -} - -func (conn *earlyConn) Read(b []byte) (n int, err error) { - err = conn.Response() - if err != nil { - return 0, err - } - return conn.bufConn.Read(b) -} - -func (conn *earlyConn) ReadBuffer(buffer *buf.Buffer) (err error) { - err = conn.Response() - if err != nil { - return err - } - return conn.bufConn.ReadBuffer(buffer) -} - -func (conn *earlyConn) Upstream() any { - return conn.bufConn -} - -func (conn *earlyConn) ReaderReplaceable() bool { - return atomic.LoadUint32((*uint32)(unsafe.Pointer(&conn.resOnce))) == 1 && conn.resErr == nil -} - -func (conn *earlyConn) WriterReplaceable() bool { - return true + return bufConn, nil } func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Metadata, dialer C.Dialer, dialFn common.DialFunc) (net.PacketConn, error) { diff --git a/transport/vmess/httpupgrade.go b/transport/vmess/httpupgrade.go deleted file mode 100644 index f7e819db..00000000 --- a/transport/vmess/httpupgrade.go +++ /dev/null @@ -1,65 +0,0 @@ -package vmess - -import ( - "fmt" - "net/http" - "strings" - "sync" - - "github.com/metacubex/mihomo/common/buf" - "github.com/metacubex/mihomo/common/net" -) - -type httpUpgradeEarlyConn struct { - *net.BufferedConn - create sync.Once - done bool - err error -} - -func (c *httpUpgradeEarlyConn) readResponse() { - var request http.Request - response, err := http.ReadResponse(c.Reader(), &request) - c.done = true - if err != nil { - c.err = err - return - } - if response.StatusCode != http.StatusSwitchingProtocols || - !strings.EqualFold(response.Header.Get("Connection"), "upgrade") || - !strings.EqualFold(response.Header.Get("Upgrade"), "websocket") { - c.err = fmt.Errorf("unexpected status: %s", response.Status) - return - } -} - -func (c *httpUpgradeEarlyConn) Read(p []byte) (int, error) { - c.create.Do(c.readResponse) - if c.err != nil { - return 0, c.err - } - return c.BufferedConn.Read(p) -} - -func (c *httpUpgradeEarlyConn) ReadBuffer(buffer *buf.Buffer) error { - c.create.Do(c.readResponse) - if c.err != nil { - return c.err - } - return c.BufferedConn.ReadBuffer(buffer) -} - -func (c *httpUpgradeEarlyConn) ReaderReplaceable() bool { - return c.done -} - -func (c *httpUpgradeEarlyConn) ReaderPossiblyReplaceable() bool { - return !c.done -} - -func (c *httpUpgradeEarlyConn) ReadCached() *buf.Buffer { - if c.done { - return c.BufferedConn.ReadCached() - } - return nil -} diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index e898400c..43faac5a 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -418,9 +418,18 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, bufferedConn := N.NewBufferedConn(conn) if c.V2rayHttpUpgrade && c.V2rayHttpUpgradeFastOpen { - return &httpUpgradeEarlyConn{ - BufferedConn: bufferedConn, - }, nil + return N.NewEarlyConn(bufferedConn, func() error { + response, err := http.ReadResponse(bufferedConn.Reader(), request) + if err != nil { + return err + } + if response.StatusCode != http.StatusSwitchingProtocols || + !strings.EqualFold(response.Header.Get("Connection"), "upgrade") || + !strings.EqualFold(response.Header.Get("Upgrade"), "websocket") { + return fmt.Errorf("unexpected status: %s", response.Status) + } + return nil + }), nil } response, err := http.ReadResponse(bufferedConn.Reader(), request)