From 3980899f3b28775ec86dc44a3a5f3792bdddaba9 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Sat, 1 Apr 2023 20:56:49 +0800 Subject: [PATCH] fix: firstWriteCallBackConn can pass N.ExtendedConn too --- adapter/outboundgroup/fallback.go | 17 ++++---- adapter/outboundgroup/loadbalance.go | 17 ++++---- adapter/outboundgroup/urltest.go | 17 ++++---- common/callback/callback.go | 62 +++++++++++++++++++++++++--- 4 files changed, 78 insertions(+), 35 deletions(-) diff --git a/adapter/outboundgroup/fallback.go b/adapter/outboundgroup/fallback.go index d1d5e6b3..02ba0ac6 100644 --- a/adapter/outboundgroup/fallback.go +++ b/adapter/outboundgroup/fallback.go @@ -37,16 +37,13 @@ func (f *Fallback) DialContext(ctx context.Context, metadata *C.Metadata, opts . } if N.NeedHandshake(c) { - c = &callback.FirstWriteCallBackConn{ - Conn: c, - Callback: func(err error) { - if err == nil { - f.onDialSuccess() - } else { - f.onDialFailed(proxy.Type(), err) - } - }, - } + c = callback.NewFirstWriteCallBackConn(c, func(err error) { + if err == nil { + f.onDialSuccess() + } else { + f.onDialFailed(proxy.Type(), err) + } + }) } return c, err diff --git a/adapter/outboundgroup/loadbalance.go b/adapter/outboundgroup/loadbalance.go index 1ed80496..15e17a13 100644 --- a/adapter/outboundgroup/loadbalance.go +++ b/adapter/outboundgroup/loadbalance.go @@ -95,16 +95,13 @@ func (lb *LoadBalance) DialContext(ctx context.Context, metadata *C.Metadata, op } if N.NeedHandshake(c) { - c = &callback.FirstWriteCallBackConn{ - Conn: c, - Callback: func(err error) { - if err == nil { - lb.onDialSuccess() - } else { - lb.onDialFailed(proxy.Type(), err) - } - }, - } + c = callback.NewFirstWriteCallBackConn(c, func(err error) { + if err == nil { + lb.onDialSuccess() + } else { + lb.onDialFailed(proxy.Type(), err) + } + }) } return diff --git a/adapter/outboundgroup/urltest.go b/adapter/outboundgroup/urltest.go index d340539c..64ce17f6 100644 --- a/adapter/outboundgroup/urltest.go +++ b/adapter/outboundgroup/urltest.go @@ -45,16 +45,13 @@ func (u *URLTest) DialContext(ctx context.Context, metadata *C.Metadata, opts .. } if N.NeedHandshake(c) { - c = &callback.FirstWriteCallBackConn{ - Conn: c, - Callback: func(err error) { - if err == nil { - u.onDialSuccess() - } else { - u.onDialFailed(proxy.Type(), err) - } - }, - } + c = callback.NewFirstWriteCallBackConn(c, func(err error) { + if err == nil { + u.onDialSuccess() + } else { + u.onDialFailed(proxy.Type(), err) + } + }) } return c, err diff --git a/common/callback/callback.go b/common/callback/callback.go index a0f1e717..9d64bb92 100644 --- a/common/callback/callback.go +++ b/common/callback/callback.go @@ -1,25 +1,77 @@ package callback import ( + "github.com/Dreamacro/clash/common/buf" + N "github.com/Dreamacro/clash/common/net" C "github.com/Dreamacro/clash/constant" ) -type FirstWriteCallBackConn struct { +type firstWriteCallBackConn struct { C.Conn - Callback func(error) + callback func(error) written bool } -func (c *FirstWriteCallBackConn) Write(b []byte) (n int, err error) { +func (c *firstWriteCallBackConn) Write(b []byte) (n int, err error) { defer func() { if !c.written { c.written = true - c.Callback(err) + c.callback(err) } }() return c.Conn.Write(b) } -func (c *FirstWriteCallBackConn) Upstream() any { +func (c *firstWriteCallBackConn) Upstream() any { return c.Conn } + +type extendedConn interface { + C.Conn + N.ExtendedConn +} + +type firstWriteCallBackExtendedConn struct { + extendedConn + callback func(error) + written bool +} + +func (c *firstWriteCallBackExtendedConn) Write(b []byte) (n int, err error) { + defer func() { + if !c.written { + c.written = true + c.callback(err) + } + }() + return c.extendedConn.Write(b) +} + +func (c *firstWriteCallBackExtendedConn) WriteBuffer(buffer *buf.Buffer) (err error) { + defer func() { + if !c.written { + c.written = true + c.callback(err) + } + }() + return c.extendedConn.WriteBuffer(buffer) +} + +func (c *firstWriteCallBackExtendedConn) Upstream() any { + return c.extendedConn +} + +func NewFirstWriteCallBackConn(c C.Conn, callback func(error)) C.Conn { + if c, ok := c.(extendedConn); ok { + return &firstWriteCallBackExtendedConn{ + extendedConn: c, + callback: callback, + written: false, + } + } + return &firstWriteCallBackConn{ + Conn: c, + callback: callback, + written: false, + } +}