diff --git a/common/net/deadline/pipe_sing.go b/common/net/deadline/pipe_sing.go new file mode 100644 index 00000000..20721fad --- /dev/null +++ b/common/net/deadline/pipe_sing.go @@ -0,0 +1,217 @@ +package deadline + +import ( + "io" + "net" + "os" + "sync" + "time" + + "github.com/sagernet/sing/common/buf" + N "github.com/sagernet/sing/common/network" +) + +type pipeAddr struct{} + +func (pipeAddr) Network() string { return "pipe" } +func (pipeAddr) String() string { return "pipe" } + +type pipe struct { + wrMu sync.Mutex // Serialize Write operations + + // Used by local Read to interact with remote Write. + // Successful receive on rdRx is always followed by send on rdTx. + rdRx <-chan []byte + rdTx chan<- int + + // Used by local Write to interact with remote Read. + // Successful send on wrTx is always followed by receive on wrRx. + wrTx chan<- []byte + wrRx <-chan int + + once sync.Once // Protects closing localDone + localDone chan struct{} + remoteDone <-chan struct{} + + readDeadline pipeDeadline + writeDeadline pipeDeadline + + readWaitOptions N.ReadWaitOptions +} + +// Pipe creates a synchronous, in-memory, full duplex +// network connection; both ends implement the Conn interface. +// Reads on one end are matched with writes on the other, +// copying data directly between the two; there is no internal +// buffering. +func Pipe() (net.Conn, net.Conn) { + cb1 := make(chan []byte) + cb2 := make(chan []byte) + cn1 := make(chan int) + cn2 := make(chan int) + done1 := make(chan struct{}) + done2 := make(chan struct{}) + + p1 := &pipe{ + rdRx: cb1, rdTx: cn1, + wrTx: cb2, wrRx: cn2, + localDone: done1, remoteDone: done2, + readDeadline: makePipeDeadline(), + writeDeadline: makePipeDeadline(), + } + p2 := &pipe{ + rdRx: cb2, rdTx: cn2, + wrTx: cb1, wrRx: cn1, + localDone: done2, remoteDone: done1, + readDeadline: makePipeDeadline(), + writeDeadline: makePipeDeadline(), + } + return p1, p2 +} + +func (*pipe) LocalAddr() net.Addr { return pipeAddr{} } +func (*pipe) RemoteAddr() net.Addr { return pipeAddr{} } + +func (p *pipe) Read(b []byte) (int, error) { + n, err := p.read(b) + if err != nil && err != io.EOF && err != io.ErrClosedPipe { + err = &net.OpError{Op: "read", Net: "pipe", Err: err} + } + return n, err +} + +func (p *pipe) read(b []byte) (n int, err error) { + switch { + case isClosedChan(p.localDone): + return 0, io.ErrClosedPipe + case isClosedChan(p.remoteDone): + return 0, io.EOF + case isClosedChan(p.readDeadline.wait()): + return 0, os.ErrDeadlineExceeded + } + + select { + case bw := <-p.rdRx: + nr := copy(b, bw) + p.rdTx <- nr + return nr, nil + case <-p.localDone: + return 0, io.ErrClosedPipe + case <-p.remoteDone: + return 0, io.EOF + case <-p.readDeadline.wait(): + return 0, os.ErrDeadlineExceeded + } +} + +func (p *pipe) Write(b []byte) (int, error) { + n, err := p.write(b) + if err != nil && err != io.ErrClosedPipe { + err = &net.OpError{Op: "write", Net: "pipe", Err: err} + } + return n, err +} + +func (p *pipe) write(b []byte) (n int, err error) { + switch { + case isClosedChan(p.localDone): + return 0, io.ErrClosedPipe + case isClosedChan(p.remoteDone): + return 0, io.ErrClosedPipe + case isClosedChan(p.writeDeadline.wait()): + return 0, os.ErrDeadlineExceeded + } + + p.wrMu.Lock() // Ensure entirety of b is written together + defer p.wrMu.Unlock() + for once := true; once || len(b) > 0; once = false { + select { + case p.wrTx <- b: + nw := <-p.wrRx + b = b[nw:] + n += nw + case <-p.localDone: + return n, io.ErrClosedPipe + case <-p.remoteDone: + return n, io.ErrClosedPipe + case <-p.writeDeadline.wait(): + return n, os.ErrDeadlineExceeded + } + } + return n, nil +} + +func (p *pipe) SetDeadline(t time.Time) error { + if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) { + return io.ErrClosedPipe + } + p.readDeadline.set(t) + p.writeDeadline.set(t) + return nil +} + +func (p *pipe) SetReadDeadline(t time.Time) error { + if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) { + return io.ErrClosedPipe + } + p.readDeadline.set(t) + return nil +} + +func (p *pipe) SetWriteDeadline(t time.Time) error { + if isClosedChan(p.localDone) || isClosedChan(p.remoteDone) { + return io.ErrClosedPipe + } + p.writeDeadline.set(t) + return nil +} + +func (p *pipe) Close() error { + p.once.Do(func() { close(p.localDone) }) + return nil +} + +var _ N.ReadWaiter = (*pipe)(nil) + +func (p *pipe) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + p.readWaitOptions = options + return false +} + +func (p *pipe) WaitReadBuffer() (buffer *buf.Buffer, err error) { + buffer, err = p.waitReadBuffer() + if err != nil && err != io.EOF && err != io.ErrClosedPipe { + err = &net.OpError{Op: "read", Net: "pipe", Err: err} + } + return +} + +func (p *pipe) waitReadBuffer() (buffer *buf.Buffer, err error) { + switch { + case isClosedChan(p.localDone): + return nil, io.ErrClosedPipe + case isClosedChan(p.remoteDone): + return nil, io.EOF + case isClosedChan(p.readDeadline.wait()): + return nil, os.ErrDeadlineExceeded + } + select { + case bw := <-p.rdRx: + buffer = p.readWaitOptions.NewBuffer() + var nr int + nr, err = buffer.Write(bw) + if err != nil { + buffer.Release() + return + } + p.readWaitOptions.PostReturn(buffer) + p.rdTx <- nr + return + case <-p.localDone: + return nil, io.ErrClosedPipe + case <-p.remoteDone: + return nil, io.EOF + case <-p.readDeadline.wait(): + return nil, os.ErrDeadlineExceeded + } +} diff --git a/common/net/sing.go b/common/net/sing.go index f8698620..3296ad5b 100644 --- a/common/net/sing.go +++ b/common/net/sing.go @@ -35,6 +35,8 @@ func NeedHandshake(conn any) bool { type CountFunc = network.CountFunc +var Pipe = deadline.Pipe + // Relay copies between left and right bidirectionally. func Relay(leftConn, rightConn net.Conn) { defer runtime.KeepAlive(leftConn) diff --git a/listener/http/client.go b/listener/http/client.go index c35cadad..dfd1985f 100644 --- a/listener/http/client.go +++ b/listener/http/client.go @@ -8,6 +8,7 @@ import ( "time" "github.com/metacubex/mihomo/adapter/inbound" + N "github.com/metacubex/mihomo/common/net" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/transport/socks5" ) @@ -30,7 +31,7 @@ func newClient(srcConn net.Conn, tunnel C.Tunnel, additions ...inbound.Addition) return nil, socks5.ErrAddressNotSupported } - left, right := net.Pipe() + left, right := N.Pipe() go tunnel.HandleTCPConn(inbound.NewHTTP(dstAddr, srcConn, right, additions...)) diff --git a/listener/http/upgrade.go b/listener/http/upgrade.go index 8a6291d1..ac67ef68 100644 --- a/listener/http/upgrade.go +++ b/listener/http/upgrade.go @@ -41,7 +41,7 @@ func handleUpgrade(conn net.Conn, request *http.Request, tunnel C.Tunnel, additi return } - left, right := net.Pipe() + left, right := N.Pipe() go tunnel.HandleTCPConn(inbound.NewHTTP(dstAddr, conn, right, additions...)) diff --git a/listener/inner/tcp.go b/listener/inner/tcp.go index 373fd2b4..dac3d721 100644 --- a/listener/inner/tcp.go +++ b/listener/inner/tcp.go @@ -6,6 +6,7 @@ import ( "net/netip" "strconv" + N "github.com/metacubex/mihomo/common/net" C "github.com/metacubex/mihomo/constant" ) @@ -20,7 +21,7 @@ func HandleTcp(address string) (conn net.Conn, err error) { return nil, errors.New("tcp uninitialized") } // executor Parsed - conn1, conn2 := net.Pipe() + conn1, conn2 := N.Pipe() metadata := &C.Metadata{} metadata.NetWork = C.TCP diff --git a/transport/tuic/v4/client.go b/transport/tuic/v4/client.go index 67906959..5c9c889c 100644 --- a/transport/tuic/v4/client.go +++ b/transport/tuic/v4/client.go @@ -364,7 +364,7 @@ func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Met return nil, common.TooManyOpenStreams } - pipe1, pipe2 := net.Pipe() + pipe1, pipe2 := N.Pipe() var connId uint32 for { connId = fastrand.Uint32() diff --git a/transport/tuic/v5/client.go b/transport/tuic/v5/client.go index 8a4d6fb1..89454add 100644 --- a/transport/tuic/v5/client.go +++ b/transport/tuic/v5/client.go @@ -348,7 +348,7 @@ func (t *clientImpl) ListenPacketWithDialer(ctx context.Context, metadata *C.Met return nil, common.TooManyOpenStreams } - pipe1, pipe2 := net.Pipe() + pipe1, pipe2 := N.Pipe() var connId uint16 for { connId = uint16(fastrand.Intn(0xFFFF)) diff --git a/tunnel/connection.go b/tunnel/connection.go index 33cc4e8d..e96545e8 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -82,6 +82,6 @@ func closeAllLocalCoon(lAddr string) { }) } -func handleSocket(ctx C.ConnContext, outbound net.Conn) { - N.Relay(ctx.Conn(), outbound) +func handleSocket(inbound, outbound net.Conn) { + N.Relay(inbound, outbound) } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 391fe7c1..18c1eb09 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -584,7 +584,7 @@ func handleTCPConn(connCtx C.ConnContext) { peekMutex.Lock() defer peekMutex.Unlock() _ = conn.SetReadDeadline(time.Time{}) // reset - handleSocket(connCtx, remoteConn) + handleSocket(conn, remoteConn) } func shouldResolveIP(rule C.Rule, metadata *C.Metadata) bool {