diff --git a/outbound/default.go b/outbound/default.go index 73885a43..9569d6ab 100644 --- a/outbound/default.go +++ b/outbound/default.go @@ -70,6 +70,28 @@ func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata a return CopyEarlyConn(ctx, conn, outConn) } +func NewDirectConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error { + ctx = adapter.WithContext(ctx, &metadata) + var outConn net.Conn + var err error + if len(metadata.DestinationAddresses) > 0 { + outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses) + } else if metadata.Destination.IsFqdn() { + var destinationAddresses []netip.Addr + destinationAddresses, err = router.LookupDefault(ctx, metadata.Destination.Fqdn) + if err != nil { + return N.HandshakeFailure(conn, err) + } + outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, destinationAddresses) + } else { + outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) + } + if err != nil { + return N.HandshakeFailure(conn, err) + } + return CopyEarlyConn(ctx, conn, outConn) +} + func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error { ctx = adapter.WithContext(ctx, &metadata) var outConn net.PacketConn @@ -99,6 +121,42 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn)) } +func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error { + ctx = adapter.WithContext(ctx, &metadata) + var outConn net.PacketConn + var destinationAddress netip.Addr + var err error + if len(metadata.DestinationAddresses) > 0 { + outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses) + } else if metadata.Destination.IsFqdn() { + var destinationAddresses []netip.Addr + destinationAddresses, err = router.LookupDefault(ctx, metadata.Destination.Fqdn) + if err != nil { + return N.HandshakeFailure(conn, err) + } + outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, destinationAddresses) + } else { + outConn, err = this.ListenPacket(ctx, metadata.Destination) + } + if err != nil { + return N.HandshakeFailure(conn, err) + } + if destinationAddress.IsValid() { + if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded { + natConn.UpdateDestination(destinationAddress) + } + } + switch metadata.Protocol { + case C.ProtocolSTUN: + ctx, conn = canceler.NewPacketConn(ctx, conn, C.STUNTimeout) + case C.ProtocolQUIC: + ctx, conn = canceler.NewPacketConn(ctx, conn, C.QUICTimeout) + case C.ProtocolDNS: + ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout) + } + return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn)) +} + func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error { if cachedReader, isCached := conn.(N.CachedReader); isCached { payload := cachedReader.ReadCached() diff --git a/outbound/socks.go b/outbound/socks.go index 48107480..58f4d7f9 100644 --- a/outbound/socks.go +++ b/outbound/socks.go @@ -80,11 +80,11 @@ func (h *Socks) DialContext(ctx context.Context, network string, destination M.S return nil, E.Extend(N.ErrUnknownNetwork, network) } if h.resolve && destination.IsFqdn() { - addrs, err := h.router.LookupDefault(ctx, destination.Fqdn) + destinationAddresses, err := h.router.LookupDefault(ctx, destination.Fqdn) if err != nil { return nil, err } - return N.DialSerial(ctx, h.client, network, destination, addrs) + return N.DialSerial(ctx, h.client, network, destination, destinationAddresses) } return h.client.DialContext(ctx, network, destination) } @@ -97,14 +97,25 @@ func (h *Socks) ListenPacket(ctx context.Context, destination M.Socksaddr) (net. h.logger.InfoContext(ctx, "outbound UoT packet connection to ", destination) return h.uotClient.ListenPacket(ctx, destination) } + if h.resolve && destination.IsFqdn() { + destinationAddresses, err := h.router.LookupDefault(ctx, destination.Fqdn) + if err != nil { + return nil, err + } + packetConn, _, err := N.ListenSerial(ctx, h.client, destination, destinationAddresses) + if err != nil { + return nil, err + } + return packetConn, nil + } h.logger.InfoContext(ctx, "outbound packet connection to ", destination) return h.client.ListenPacket(ctx, destination) } func (h *Socks) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { - return NewConnection(ctx, h, conn, metadata) + return NewDirectConnection(ctx, h.router, h, conn, metadata) } func (h *Socks) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { - return NewPacketConnection(ctx, h, conn, metadata) + return NewDirectPacketConnection(ctx, h.router, h, conn, metadata) } diff --git a/outbound/wireguard.go b/outbound/wireguard.go index bc24e2e7..e141fbb5 100644 --- a/outbound/wireguard.go +++ b/outbound/wireguard.go @@ -202,26 +202,37 @@ func (w *WireGuard) DialContext(ctx context.Context, network string, destination w.logger.InfoContext(ctx, "outbound packet connection to ", destination) } if destination.IsFqdn() { - addrs, err := w.router.LookupDefault(ctx, destination.Fqdn) + destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn) if err != nil { return nil, err } - return N.DialSerial(ctx, w.tunDevice, network, destination, addrs) + return N.DialSerial(ctx, w.tunDevice, network, destination, destinationAddresses) } return w.tunDevice.DialContext(ctx, network, destination) } func (w *WireGuard) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { w.logger.InfoContext(ctx, "outbound packet connection to ", destination) + if destination.IsFqdn() { + destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn) + if err != nil { + return nil, err + } + packetConn, _, err := N.ListenSerial(ctx, w.tunDevice, destination, destinationAddresses) + if err != nil { + return nil, err + } + return packetConn, err + } return w.tunDevice.ListenPacket(ctx, destination) } func (w *WireGuard) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { - return NewConnection(ctx, w, conn, metadata) + return NewDirectConnection(ctx, w.router, w, conn, metadata) } func (w *WireGuard) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { - return NewPacketConnection(ctx, w, conn, metadata) + return NewDirectPacketConnection(ctx, w.router, w, conn, metadata) } func (w *WireGuard) Start() error {