diff --git a/outbound/direct.go b/outbound/direct.go index b252a771..256af894 100644 --- a/outbound/direct.go +++ b/outbound/direct.go @@ -12,6 +12,8 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-dns" + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -160,8 +162,30 @@ func (h *Direct) ListenPacket(ctx context.Context, destination M.Socksaddr) (net ctx, metadata := adapter.AppendContext(ctx) metadata.Outbound = h.tag metadata.Destination = destination - h.logger.InfoContext(ctx, "outbound packet connection") - return h.dialer.ListenPacket(ctx, destination) + switch h.overrideOption { + case 1: + destination = h.overrideDestination + case 2: + newDestination := h.overrideDestination + newDestination.Port = destination.Port + destination = newDestination + case 3: + destination.Port = h.overrideDestination.Port + } + if h.overrideOption == 0 { + h.logger.InfoContext(ctx, "outbound packet connection") + } else { + h.logger.InfoContext(ctx, "outbound packet connection to ", destination) + } + conn, err := h.dialer.ListenPacket(ctx, destination) + if err != nil { + return nil, err + } + if h.overrideOption == 0 { + return conn, nil + } else { + return &overridePacketConn{bufio.NewPacketConn(conn), destination}, nil + } } func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { @@ -171,3 +195,20 @@ func (h *Direct) NewConnection(ctx context.Context, conn net.Conn, metadata adap func (h *Direct) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { return NewPacketConnection(ctx, h, conn, metadata) } + +type overridePacketConn struct { + N.NetPacketConn + overrideDestination M.Socksaddr +} + +func (c *overridePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + return c.NetPacketConn.WritePacket(buffer, c.overrideDestination) +} + +func (c *overridePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + return c.NetPacketConn.WriteTo(p, c.overrideDestination.UDPAddr()) +} + +func (c *overridePacketConn) Upstream() any { + return c.NetPacketConn +}