From 9ff246b29d148d718a94f3fc0e15f01d8d8cb705 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 19 May 2023 23:29:43 +0800 Subject: [PATCH] chore: better packet deadline --- common/net/deadline/packet.go | 109 ++++++++++---------------- common/net/deadline/packet_enhance.go | 73 +++++++++++++++++ common/net/deadline/packet_sing.go | 96 +++++++++++++++++++++++ common/net/packet.go | 4 +- common/net/packet/packet.go | 6 +- common/net/packet/packet_sing.go | 5 ++ 6 files changed, 222 insertions(+), 71 deletions(-) create mode 100644 common/net/deadline/packet_enhance.go create mode 100644 common/net/deadline/packet_sing.go diff --git a/common/net/deadline/packet.go b/common/net/deadline/packet.go index 38b5f579..f68aadaf 100644 --- a/common/net/deadline/packet.go +++ b/common/net/deadline/packet.go @@ -11,12 +11,13 @@ import ( type readResult struct { data []byte - put func() addr net.Addr err error + enhanceReadResult + singReadResult } -type PacketConn struct { +type NetPacketConn struct { net.PacketConn deadline atomic.TypedValue[time.Time] pipeDeadline pipeDeadline @@ -25,23 +26,45 @@ type PacketConn struct { resultCh chan *readResult } -func NewPacketConn(pc net.PacketConn) net.PacketConn { - c := &PacketConn{ +func NewNetPacketConn(pc net.PacketConn) net.PacketConn { + npc := &NetPacketConn{ PacketConn: pc, pipeDeadline: makePipeDeadline(), resultCh: make(chan *readResult, 1), } - c.resultCh <- nil - if enhancePacketConn, isEnhance := pc.(packet.EnhancePacketConn); isEnhance { - return &EnhancePacketConn{ - PacketConn: c, - enhancePacketConn: enhancePacketConn, + npc.resultCh <- nil + if enhancePC, isEnhance := pc.(packet.EnhancePacketConn); isEnhance { + epc := &EnhancePacketConn{ + NetPacketConn: npc, + enhancePacketConn: enhancePacketConn{ + netPacketConn: npc, + enhancePacketConn: enhancePC, + }, + } + if singPC, isSingPC := pc.(packet.SingPacketConn); isSingPC { + return &EnhanceSingPacketConn{ + EnhancePacketConn: epc, + singPacketConn: singPacketConn{ + netPacketConn: npc, + singPacketConn: singPC, + }, + } + } + return epc + } + if singPC, isSingPC := pc.(packet.SingPacketConn); isSingPC { + return &SingPacketConn{ + NetPacketConn: npc, + singPacketConn: singPacketConn{ + netPacketConn: npc, + singPacketConn: singPC, + }, } } - return c + return npc } -func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { +func (c *NetPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { select { case result := <-c.resultCh: if result != nil { @@ -73,7 +96,7 @@ func (c *PacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { return c.ReadFrom(p) } -func (c *PacketConn) pipeReadFrom(size int) { +func (c *NetPacketConn) pipeReadFrom(size int) { buffer := make([]byte, size) n, addr, err := c.PacketConn.ReadFrom(buffer) buffer = buffer[:n] @@ -84,7 +107,7 @@ func (c *PacketConn) pipeReadFrom(size int) { } } -func (c *PacketConn) SetReadDeadline(t time.Time) error { +func (c *NetPacketConn) SetReadDeadline(t time.Time) error { if c.disablePipe.Load() { return c.PacketConn.SetReadDeadline(t) } else if c.inRead.Load() { @@ -96,7 +119,7 @@ func (c *PacketConn) SetReadDeadline(t time.Time) error { return nil } -func (c *PacketConn) ReaderReplaceable() bool { +func (c *NetPacketConn) ReaderReplaceable() bool { select { case result := <-c.resultCh: c.resultCh <- result @@ -111,66 +134,14 @@ func (c *PacketConn) ReaderReplaceable() bool { return c.disablePipe.Load() || c.deadline.Load().IsZero() } -func (c *PacketConn) WriterReplaceable() bool { +func (c *NetPacketConn) WriterReplaceable() bool { return true } -func (c *PacketConn) Upstream() any { +func (c *NetPacketConn) Upstream() any { return c.PacketConn } -func (c *PacketConn) NeedAdditionalReadDeadline() bool { +func (c *NetPacketConn) NeedAdditionalReadDeadline() bool { return false } - -type EnhancePacketConn struct { - *PacketConn - enhancePacketConn packet.EnhancePacketConn -} - -func NewEnhancePacketConn(pc packet.EnhancePacketConn) packet.EnhancePacketConn { - return NewPacketConn(pc).(packet.EnhancePacketConn) -} - -func (c *EnhancePacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { - select { - case result := <-c.resultCh: - if result != nil { - data = result.data - put = result.put - addr = result.addr - err = result.err - c.resultCh <- nil // finish cache read - return - } else { - c.resultCh <- nil - break - } - case <-c.pipeDeadline.wait(): - return nil, nil, nil, os.ErrDeadlineExceeded - } - - if c.disablePipe.Load() { - return c.enhancePacketConn.WaitReadFrom() - } else if c.deadline.Load().IsZero() { - c.inRead.Store(true) - defer c.inRead.Store(false) - data, put, addr, err = c.enhancePacketConn.WaitReadFrom() - return - } - - <-c.resultCh - go c.pipeWaitReadFrom() - - return c.WaitReadFrom() -} - -func (c *EnhancePacketConn) pipeWaitReadFrom() { - data, put, addr, err := c.enhancePacketConn.WaitReadFrom() - c.resultCh <- &readResult{ - data: data, - put: put, - addr: addr, - err: err, - } -} diff --git a/common/net/deadline/packet_enhance.go b/common/net/deadline/packet_enhance.go new file mode 100644 index 00000000..589e1447 --- /dev/null +++ b/common/net/deadline/packet_enhance.go @@ -0,0 +1,73 @@ +package deadline + +import ( + "net" + "os" + + "github.com/Dreamacro/clash/common/net/packet" +) + +type EnhancePacketConn struct { + *NetPacketConn + enhancePacketConn +} + +var _ packet.EnhancePacketConn = (*EnhancePacketConn)(nil) + +func NewEnhancePacketConn(pc packet.EnhancePacketConn) packet.EnhancePacketConn { + return NewNetPacketConn(pc).(packet.EnhancePacketConn) +} + +type enhanceReadResult struct { + put func() +} + +type enhancePacketConn struct { + netPacketConn *NetPacketConn + enhancePacketConn packet.EnhancePacketConn +} + +func (c *enhancePacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) { + select { + case result := <-c.netPacketConn.resultCh: + if result != nil { + data = result.data + put = result.put + addr = result.addr + err = result.err + c.netPacketConn.resultCh <- nil // finish cache read + return + } else { + c.netPacketConn.resultCh <- nil + break + } + case <-c.netPacketConn.pipeDeadline.wait(): + return nil, nil, nil, os.ErrDeadlineExceeded + } + + if c.netPacketConn.disablePipe.Load() { + return c.enhancePacketConn.WaitReadFrom() + } else if c.netPacketConn.deadline.Load().IsZero() { + c.netPacketConn.inRead.Store(true) + defer c.netPacketConn.inRead.Store(false) + data, put, addr, err = c.enhancePacketConn.WaitReadFrom() + return + } + + <-c.netPacketConn.resultCh + go c.pipeWaitReadFrom() + + return c.WaitReadFrom() +} + +func (c *enhancePacketConn) pipeWaitReadFrom() { + data, put, addr, err := c.enhancePacketConn.WaitReadFrom() + c.netPacketConn.resultCh <- &readResult{ + data: data, + enhanceReadResult: enhanceReadResult{ + put: put, + }, + addr: addr, + err: err, + } +} diff --git a/common/net/deadline/packet_sing.go b/common/net/deadline/packet_sing.go new file mode 100644 index 00000000..f69022ab --- /dev/null +++ b/common/net/deadline/packet_sing.go @@ -0,0 +1,96 @@ +package deadline + +import ( + "os" + + "github.com/Dreamacro/clash/common/net/packet" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" +) + +type SingPacketConn struct { + *NetPacketConn + singPacketConn +} + +var _ packet.SingPacketConn = (*SingPacketConn)(nil) + +func NewSingPacketConn(pc packet.SingPacketConn) packet.SingPacketConn { + return NewNetPacketConn(pc).(packet.SingPacketConn) +} + +type EnhanceSingPacketConn struct { + *EnhancePacketConn + singPacketConn +} + +func NewEnhanceSingPacketConn(pc packet.EnhanceSingPacketConn) packet.EnhanceSingPacketConn { + return NewNetPacketConn(pc).(packet.EnhanceSingPacketConn) +} + +var _ packet.EnhanceSingPacketConn = (*EnhanceSingPacketConn)(nil) + +type singReadResult struct { + buffer *buf.Buffer + destination M.Socksaddr +} + +type singPacketConn struct { + netPacketConn *NetPacketConn + singPacketConn packet.SingPacketConn +} + +func (c *singPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + select { + case result := <-c.netPacketConn.resultCh: + if result != nil { + destination = result.destination + err = result.err + buffer.Resize(result.buffer.Start(), 0) + n := copy(buffer.FreeBytes(), result.buffer.Bytes()) + buffer.Truncate(n) + result.buffer.Advance(n) + if result.buffer.IsEmpty() { + result.buffer.Release() + } + c.netPacketConn.resultCh <- nil // finish cache read + return + } else { + c.netPacketConn.resultCh <- nil + break + } + case <-c.netPacketConn.pipeDeadline.wait(): + return M.Socksaddr{}, os.ErrDeadlineExceeded + } + + if c.netPacketConn.disablePipe.Load() { + return c.singPacketConn.ReadPacket(buffer) + } else if c.netPacketConn.deadline.Load().IsZero() { + c.netPacketConn.inRead.Store(true) + defer c.netPacketConn.inRead.Store(false) + destination, err = c.singPacketConn.ReadPacket(buffer) + return + } + + <-c.netPacketConn.resultCh + go c.pipeReadPacket(buffer.Cap(), buffer.Start()) + + return c.ReadPacket(buffer) +} + +func (c *singPacketConn) pipeReadPacket(bufLen int, bufStart int) { + buffer := buf.NewSize(bufLen) + buffer.Advance(bufStart) + destination, err := c.singPacketConn.ReadPacket(buffer) + c.netPacketConn.resultCh <- &readResult{ + singReadResult: singReadResult{ + buffer: buffer, + destination: destination, + }, + err: err, + } +} + +func (c *singPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + return c.singPacketConn.WritePacket(buffer, destination) +} diff --git a/common/net/packet.go b/common/net/packet.go index d01c9efe..865590ce 100644 --- a/common/net/packet.go +++ b/common/net/packet.go @@ -11,8 +11,10 @@ import ( type EnhancePacketConn = packet.EnhancePacketConn var NewEnhancePacketConn = packet.NewEnhancePacketConn -var NewDeadlinePacketConn = deadline.NewPacketConn +var NewDeadlineNetPacketConn = deadline.NewNetPacketConn var NewDeadlineEnhancePacketConn = deadline.NewEnhancePacketConn +var NewDeadlineSingPacketConn = deadline.NewSingPacketConn +var NewDeadlineEnhanceSingPacketConn = deadline.NewEnhanceSingPacketConn type threadSafePacketConn struct { EnhancePacketConn diff --git a/common/net/packet/packet.go b/common/net/packet/packet.go index 6329803b..6c9542c1 100644 --- a/common/net/packet/packet.go +++ b/common/net/packet/packet.go @@ -6,9 +6,13 @@ import ( "github.com/Dreamacro/clash/common/pool" ) +type WaitReadFrom interface { + WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) +} + type EnhancePacketConn interface { net.PacketConn - WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) + WaitReadFrom } func NewEnhancePacketConn(pc net.PacketConn) EnhancePacketConn { diff --git a/common/net/packet/packet_sing.go b/common/net/packet/packet_sing.go index 9be1a4a1..daa352c8 100644 --- a/common/net/packet/packet_sing.go +++ b/common/net/packet/packet_sing.go @@ -11,6 +11,11 @@ import ( type SingPacketConn = N.NetPacketConn +type EnhanceSingPacketConn interface { + N.NetPacketConn + EnhancePacketConn +} + type enhanceSingPacketConn struct { N.NetPacketConn readWaiter N.PacketReadWaiter