diff --git a/adapter/outbound/vmess.go b/adapter/outbound/vmess.go index c5212eef..710e19bb 100644 --- a/adapter/outbound/vmess.go +++ b/adapter/outbound/vmess.go @@ -507,9 +507,9 @@ type vmessPacketConn struct { // WriteTo implments C.PacketConn.WriteTo // Since VMess doesn't support full cone NAT by design, we verify if addr matches uc.rAddr, and drop the packet if not. func (uc *vmessPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { - allowedAddr := uc.rAddr.(*net.UDPAddr) - destAddr := addr.(*net.UDPAddr) - if !(allowedAddr.IP.Equal(destAddr.IP) && allowedAddr.Port == destAddr.Port) { + allowedAddr := uc.rAddr + destAddr := addr + if allowedAddr.String() != destAddr.String() { return 0, ErrUDPRemoteAddrMismatch } uc.access.Lock() diff --git a/component/proxydialer/proxydialer.go b/component/proxydialer/proxydialer.go index a32e54d1..8a3ab263 100644 --- a/component/proxydialer/proxydialer.go +++ b/component/proxydialer/proxydialer.go @@ -33,8 +33,8 @@ func NewByName(proxyName string, dialer C.Dialer) (C.Dialer, error) { } func (p proxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - currentMeta, err := addrToMetadata(address) - if err != nil { + currentMeta := &C.Metadata{Type: C.INNER} + if err := currentMeta.SetRemoteAddress(address); err != nil { return nil, err } if strings.Contains(network, "udp") { // using in wireguard outbound @@ -42,9 +42,14 @@ func (p proxyDialer) DialContext(ctx context.Context, network, address string) ( if err != nil { return nil, err } - return N.NewBindPacketConn(pc, currentMeta.UDPAddr()), nil + var rAddr net.Addr = currentMeta.UDPAddr() + if rAddr == nil { // the domain name was not resolved, will appear in not stream-oriented udp like Shadowsocks/Tuic + rAddr = N.NewCustomAddr("udp", currentMeta.RemoteAddress(), nil) + } + return N.NewBindPacketConn(pc, rAddr), nil } var conn C.Conn + var err error if d, ok := p.dialer.(dialer.Dialer); ok { // first using old function to let mux work conn, err = p.proxy.DialContext(ctx, currentMeta, dialer.WithOption(d.Opt)) } else { @@ -60,8 +65,8 @@ func (p proxyDialer) DialContext(ctx context.Context, network, address string) ( } func (p proxyDialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) { - currentMeta, err := addrToMetadata(rAddrPort.String()) - if err != nil { + currentMeta := &C.Metadata{Type: C.INNER} + if err := currentMeta.SetRemoteAddress(address); err != nil { return nil, err } return p.listenPacket(ctx, currentMeta) @@ -84,27 +89,3 @@ func (p proxyDialer) listenPacket(ctx context.Context, currentMeta *C.Metadata) } return pc, nil } - -func addrToMetadata(rawAddress string) (addr *C.Metadata, err error) { - host, port, err := net.SplitHostPort(rawAddress) - if err != nil { - err = fmt.Errorf("addrToMetadata failed: %w", err) - return - } - - if ip, err := netip.ParseAddr(host); err != nil { - addr = &C.Metadata{ - Host: host, - DstPort: port, - } - } else { - addr = &C.Metadata{ - Host: "", - DstIP: ip.Unmap(), - DstPort: port, - } - } - addr.Type = C.INNER - - return -} diff --git a/constant/metadata.go b/constant/metadata.go index 4ff20305..1c344d5d 100644 --- a/constant/metadata.go +++ b/constant/metadata.go @@ -229,3 +229,21 @@ func (m *Metadata) String() string { func (m *Metadata) Valid() bool { return m.Host != "" || m.DstIP.IsValid() } + +func (m *Metadata) SetRemoteAddress(rawAddress string) error { + host, port, err := net.SplitHostPort(rawAddress) + if err != nil { + return err + } + + if ip, err := netip.ParseAddr(host); err != nil { + m.Host = host + m.DstIP = netip.Addr{} + } else { + m.Host = "" + m.DstIP = ip.Unmap() + } + m.DstPort = port + + return nil +} diff --git a/transport/tuic/conn.go b/transport/tuic/conn.go index 8f63da75..567f6ce5 100644 --- a/transport/tuic/conn.go +++ b/transport/tuic/conn.go @@ -2,7 +2,6 @@ package tuic import ( "net" - "net/netip" "sync" "sync/atomic" "time" @@ -216,11 +215,11 @@ func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err erro } buf := pool.GetBuffer() defer pool.PutBuffer(buf) - addrPort, err := netip.ParseAddrPort(addr.String()) + address, err := NewAddressNetAddr(addr) if err != nil { return } - err = NewPacket(q.connId, uint16(len(p)), NewAddressAddrPort(addrPort), p).WriteTo(buf) + err = NewPacket(q.connId, uint16(len(p)), address, p).WriteTo(buf) if err != nil { return } diff --git a/transport/tuic/protocol.go b/transport/tuic/protocol.go index 570b6e54..a460eecc 100644 --- a/transport/tuic/protocol.go +++ b/transport/tuic/protocol.go @@ -464,6 +464,18 @@ func NewAddress(metadata *C.Metadata) Address { } } +func NewAddressNetAddr(addr net.Addr) (Address, error) { + addrStr := addr.String() + if addrPort, err := netip.ParseAddrPort(addrStr); err == nil { + return NewAddressAddrPort(addrPort), nil + } + metadata := &C.Metadata{} + if err := metadata.SetRemoteAddress(addrStr); err != nil { + return Address{}, err + } + return NewAddress(metadata), nil +} + func NewAddressAddrPort(addrPort netip.AddrPort) Address { var addrType byte port := addrPort.Port()