sing-box/transport/trojan/mux.go

85 lines
2.2 KiB
Go
Raw Normal View History

package trojan
import (
2024-06-24 09:49:15 +08:00
std_bufio "bufio"
"context"
"net"
2024-11-23 22:34:02 +08:00
"os"
2024-06-24 09:49:15 +08:00
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
2024-11-02 00:39:02 +08:00
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
2024-11-23 22:34:02 +08:00
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/task"
"github.com/sagernet/smux"
)
2024-11-23 22:34:02 +08:00
func HandleMuxConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, handler Handler, logger logger.ContextLogger, onClose N.CloseHandlerFunc) error {
session, err := smux.Server(conn, smuxConfig())
if err != nil {
return err
}
var group task.Group
group.Append0(func(_ context.Context) error {
var stream net.Conn
for {
stream, err = session.AcceptStream()
if err != nil {
return err
}
2024-11-23 22:34:02 +08:00
go newMuxConnection(ctx, stream, source, handler, logger)
}
})
group.Cleanup(func() {
session.Close()
2024-11-23 22:34:02 +08:00
if onClose != nil {
onClose(os.ErrClosed)
}
})
return group.Run(ctx)
}
2024-11-23 22:34:02 +08:00
func newMuxConnection(ctx context.Context, conn net.Conn, source M.Socksaddr, handler Handler, logger logger.ContextLogger) {
err := newMuxConnection0(ctx, conn, source, handler)
if err != nil {
2024-11-02 00:39:02 +08:00
logger.ErrorContext(ctx, E.Cause(err, "process trojan-go multiplex connection"))
}
}
2024-11-23 22:34:02 +08:00
func newMuxConnection0(ctx context.Context, conn net.Conn, source M.Socksaddr, handler Handler) error {
2024-06-24 09:49:15 +08:00
reader := std_bufio.NewReader(conn)
command, err := reader.ReadByte()
if err != nil {
return E.Cause(err, "read command")
}
2024-11-23 22:34:02 +08:00
destination, err := M.SocksaddrSerializer.ReadAddrPort(reader)
if err != nil {
return E.Cause(err, "read destination")
}
2024-06-24 09:49:15 +08:00
if reader.Buffered() > 0 {
buffer := buf.NewSize(reader.Buffered())
_, err = buffer.ReadFullFrom(reader, buffer.Len())
if err != nil {
return err
}
conn = bufio.NewCachedConn(conn, buffer)
}
switch command {
case CommandTCP:
2024-11-23 22:34:02 +08:00
handler.NewConnectionEx(ctx, conn, source, destination, nil)
case CommandUDP:
2024-11-23 22:34:02 +08:00
handler.NewPacketConnectionEx(ctx, &PacketConn{Conn: conn}, source, destination, nil)
default:
return E.New("unknown command ", command)
}
2024-11-23 22:34:02 +08:00
return nil
}
// smuxConfig returns the smux configuration used for server sessions: the
// library defaults with KeepAliveDisabled set to true.
func smuxConfig() *smux.Config {
	cfg := smux.DefaultConfig()
	cfg.KeepAliveDisabled = true
	return cfg
}