sing-box/common/dialer/tfo.go

160 lines
3.1 KiB
Go
Raw Normal View History

2023-08-16 17:47:24 +08:00
//go:build go1.20
2022-10-05 21:02:44 +08:00
package dialer
import (
"context"
"io"
"net"
"os"
2023-09-09 19:51:10 +08:00
"sync"
2022-10-05 21:02:44 +08:00
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
2022-10-18 13:28:03 +08:00
E "github.com/sagernet/sing/common/exceptions"
2022-10-05 21:02:44 +08:00
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
2023-02-07 18:06:25 +08:00
"github.com/sagernet/tfo-go"
2022-10-05 21:02:44 +08:00
)
type slowOpenConn struct {
dialer *tfo.Dialer
ctx context.Context
network string
destination M.Socksaddr
conn net.Conn
create chan struct{}
2023-09-09 19:51:10 +08:00
access sync.Mutex
2022-10-05 21:02:44 +08:00
err error
}
2023-08-16 17:47:24 +08:00
func DialSlowContext(dialer *tcpDialer, ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
2022-10-05 21:02:44 +08:00
if dialer.DisableTFO || N.NetworkName(network) != N.NetworkTCP {
2023-06-07 20:28:21 +08:00
switch N.NetworkName(network) {
case N.NetworkTCP, N.NetworkUDP:
return dialer.Dialer.DialContext(ctx, network, destination.String())
default:
return dialer.Dialer.DialContext(ctx, network, destination.AddrString())
}
2022-10-05 21:02:44 +08:00
}
return &slowOpenConn{
dialer: dialer,
ctx: ctx,
network: network,
destination: destination,
create: make(chan struct{}),
}, nil
}
// Read reads from the underlying connection, first waiting for the initial
// Write to finish dialing. It returns the dial error if that dial failed, or
// ctx.Err() if the context is cancelled before the dial finishes.
//
// c.conn and c.err are only read after receiving from the closed create
// channel. That receive orders these reads after the writes made by the
// dialing goroutine, so no lock is needed and the race detector stays quiet.
func (c *slowOpenConn) Read(b []byte) (n int, err error) {
	select {
	case <-c.create:
		// Already dialed: skip the ctx check, as the original fast path did.
	default:
		select {
		case <-c.create:
		case <-c.ctx.Done():
			return 0, c.ctx.Err()
		}
	}
	if c.err != nil {
		return 0, c.err
	}
	return c.conn.Read(b)
}
func (c *slowOpenConn) Write(b []byte) (n int, err error) {
2023-09-09 19:51:10 +08:00
if c.conn != nil {
return c.conn.Write(b)
}
c.access.Lock()
defer c.access.Unlock()
select {
case <-c.create:
if c.err != nil {
return 0, c.err
2022-10-18 13:28:03 +08:00
}
2023-09-09 19:51:10 +08:00
return c.conn.Write(b)
default:
}
c.conn, err = c.dialer.DialContext(c.ctx, c.network, c.destination.String(), b)
if err != nil {
c.conn = nil
c.err = E.Cause(err, "dial tcp fast open")
2022-10-05 21:02:44 +08:00
}
2024-01-17 10:10:36 +08:00
n = len(b)
2023-09-09 19:51:10 +08:00
close(c.create)
return
2022-10-05 21:02:44 +08:00
}
// Close closes the underlying connection.
//
// If no dial has happened yet, there is nothing to close. Instead the conn is
// marked closed by recording net.ErrClosed and closing create. This wakes any
// Read/WriteTo blocked waiting for the dial, and stops a later Write from
// dialing a connection that nobody would close.
//
// Close takes c.access, so it waits for an in-flight dial to finish before
// closing the resulting connection.
func (c *slowOpenConn) Close() error {
	c.access.Lock()
	defer c.access.Unlock()
	select {
	case <-c.create:
		return common.Close(c.conn)
	default:
		c.err = net.ErrClosed
		close(c.create)
		return nil
	}
}
// LocalAddr returns the local address of the dialed connection, or an empty
// M.Socksaddr while no connection exists.
func (c *slowOpenConn) LocalAddr() net.Addr {
	if conn := c.conn; conn != nil {
		return conn.LocalAddr()
	}
	return M.Socksaddr{}
}
// RemoteAddr returns the peer address of the dialed connection, or an empty
// M.Socksaddr while no connection exists.
func (c *slowOpenConn) RemoteAddr() net.Addr {
	if conn := c.conn; conn != nil {
		return conn.RemoteAddr()
	}
	return M.Socksaddr{}
}
// SetDeadline forwards to the dialed connection; before a connection exists
// it fails with os.ErrInvalid.
func (c *slowOpenConn) SetDeadline(t time.Time) error {
	if conn := c.conn; conn != nil {
		return conn.SetDeadline(t)
	}
	return os.ErrInvalid
}
// SetReadDeadline forwards to the dialed connection; before a connection
// exists it fails with os.ErrInvalid.
func (c *slowOpenConn) SetReadDeadline(t time.Time) error {
	if conn := c.conn; conn != nil {
		return conn.SetReadDeadline(t)
	}
	return os.ErrInvalid
}
// SetWriteDeadline forwards to the dialed connection; before a connection
// exists it fails with os.ErrInvalid.
func (c *slowOpenConn) SetWriteDeadline(t time.Time) error {
	if conn := c.conn; conn != nil {
		return conn.SetWriteDeadline(t)
	}
	return os.ErrInvalid
}
// Upstream exposes the wrapped connection. It is nil until a dial succeeds.
func (c *slowOpenConn) Upstream() any {
	upstream := c.conn
	return upstream
}
// ReaderReplaceable reports true once a dialed connection exists.
// NOTE(review): presumably this lets sing's copy helpers read from Upstream()
// directly instead of through this wrapper — confirm against sing/common/network.
func (c *slowOpenConn) ReaderReplaceable() bool {
	dialed := c.conn != nil
	return dialed
}
// WriterReplaceable reports true once a dialed connection exists.
// NOTE(review): presumably this lets sing's copy helpers write to Upstream()
// directly instead of through this wrapper — confirm against sing/common/network.
func (c *slowOpenConn) WriterReplaceable() bool {
	dialed := c.conn != nil
	return dialed
}
2022-10-10 13:31:45 +08:00
// LazyHeadroom reports true while no connection has been dialed yet.
// NOTE(review): exact meaning is defined by the sing interface this satisfies;
// presumably it defers headroom decisions until the first Write — confirm.
func (c *slowOpenConn) LazyHeadroom() bool {
	pending := c.conn == nil
	return pending
}
2023-03-23 17:14:38 +08:00
// NeedHandshake reports true while no connection has been dialed yet, i.e.
// until a Write has carried the first payload through the fast-open dial.
func (c *slowOpenConn) NeedHandshake() bool {
	pending := c.conn == nil
	return pending
}
2022-10-05 21:02:44 +08:00
// WriteTo copies everything read from the underlying connection into w. It
// first waits for the initial Write to finish dialing. It returns the dial
// error if that dial failed, or ctx.Err() if the context is cancelled first.
//
// As in Read, c.conn/c.err are only read after receiving from the closed
// create channel, which avoids the unsynchronized c.conn check that raced
// with Write.
func (c *slowOpenConn) WriteTo(w io.Writer) (n int64, err error) {
	select {
	case <-c.create:
		// Already dialed: skip the ctx check, as the original fast path did.
	default:
		select {
		case <-c.create:
		case <-c.ctx.Done():
			return 0, c.ctx.Err()
		}
	}
	if c.err != nil {
		return 0, c.err
	}
	return bufio.Copy(w, c.conn)
}