diff --git a/adapter/outbound/default.go b/adapter/outbound/default.go index 78b9bfd8..bb58ff54 100644 --- a/adapter/outbound/default.go +++ b/adapter/outbound/default.go @@ -20,6 +20,7 @@ import ( ) func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext) error { + defer conn.Close() ctx = adapter.WithContext(ctx, &metadata) var outConn net.Conn var err error @@ -40,6 +41,7 @@ func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata a } func NewDirectConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, domainStrategy dns.DomainStrategy) error { + defer conn.Close() ctx = adapter.WithContext(ctx, &metadata) var outConn net.Conn var err error @@ -67,29 +69,49 @@ func NewDirectConnection(ctx context.Context, router adapter.Router, this N.Dial } func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext) error { + defer conn.Close() ctx = adapter.WithContext(ctx, &metadata) - var outConn net.PacketConn - var destinationAddress netip.Addr - var err error - if len(metadata.DestinationAddresses) > 0 { - outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses) + var ( + outPacketConn net.PacketConn + outConn net.Conn + destinationAddress netip.Addr + err error + ) + if metadata.UDPConnect { + if len(metadata.DestinationAddresses) > 0 { + outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses) + } else { + outConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination) + } + if err != nil { + return N.ReportHandshakeFailure(conn, err) + } + outPacketConn = bufio.NewUnbindPacketConn(outConn) + connRemoteAddr := M.AddrFromNet(outConn.RemoteAddr()) + if connRemoteAddr != metadata.Destination.Addr { + destinationAddress = connRemoteAddr + } } else { - outConn, err = this.ListenPacket(ctx, metadata.Destination) + if len(metadata.DestinationAddresses) > 0 { + outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses) + } else { + outPacketConn, err = this.ListenPacket(ctx, metadata.Destination) + } + if err != nil { + return N.ReportHandshakeFailure(conn, err) + } } + err = N.ReportPacketConnHandshakeSuccess(conn, outPacketConn) if err != nil { - return N.ReportHandshakeFailure(conn, err) - } - err = N.ReportPacketConnHandshakeSuccess(conn, outConn) - if err != nil { - outConn.Close() + outPacketConn.Close() return err } if destinationAddress.IsValid() { if metadata.Destination.IsFqdn() { if metadata.UDPDisableDomainUnmapping { - outConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination) + outPacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination) } else { - outConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination) + outPacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination) } } if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded { @@ -104,37 +126,63 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, case C.ProtocolDNS: ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout) } - return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn)) + return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outPacketConn)) } func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, domainStrategy dns.DomainStrategy) error { + defer conn.Close() ctx = adapter.WithContext(ctx, &metadata) - var outConn net.PacketConn - var destinationAddress netip.Addr - var err error - if len(metadata.DestinationAddresses) > 0 { - outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses) - } else if metadata.Destination.IsFqdn() { - var destinationAddresses []netip.Addr - destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy) + var ( + outPacketConn net.PacketConn + outConn net.Conn + destinationAddress netip.Addr + err error + ) + if metadata.UDPConnect { + if len(metadata.DestinationAddresses) > 0 { + outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses) + } else if metadata.Destination.IsFqdn() { + var destinationAddresses []netip.Addr + destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy) + if err != nil { + return N.ReportHandshakeFailure(conn, err) + } + outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, destinationAddresses) + } else { + outConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination) + } if err != nil { return N.ReportHandshakeFailure(conn, err) } - outConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, destinationAddresses) + connRemoteAddr := M.AddrFromNet(outConn.RemoteAddr()) + if connRemoteAddr != metadata.Destination.Addr { + destinationAddress = connRemoteAddr + } } else { - outConn, err = this.ListenPacket(ctx, metadata.Destination) + if len(metadata.DestinationAddresses) > 0 { + outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses) + } else if metadata.Destination.IsFqdn() { + var destinationAddresses []netip.Addr + destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy) + if err != nil { + return N.ReportHandshakeFailure(conn, err) + } + outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, destinationAddresses) + } else { + outPacketConn, err = this.ListenPacket(ctx, metadata.Destination) + } + if err != nil { + return N.ReportHandshakeFailure(conn, err) + } } + err = N.ReportPacketConnHandshakeSuccess(conn, outPacketConn) if err != nil { - return N.ReportHandshakeFailure(conn, err) - } - err = N.ReportPacketConnHandshakeSuccess(conn, outConn) - if err != nil { - outConn.Close() + outPacketConn.Close() return err } if destinationAddress.IsValid() { if metadata.Destination.IsFqdn() { - outConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination) + outPacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination) } if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded { natConn.UpdateDestination(destinationAddress) @@ -148,7 +196,7 @@ func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this case C.ProtocolDNS: ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout) } - return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outConn)) + return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outPacketConn)) } func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error {