diff --git a/adapter/outbound/direct.go b/adapter/outbound/direct.go index 7def7b20..1b01a576 100644 --- a/adapter/outbound/direct.go +++ b/adapter/outbound/direct.go @@ -3,18 +3,18 @@ package outbound import ( "context" "errors" - "fmt" "net/netip" N "github.com/metacubex/mihomo/common/net" "github.com/metacubex/mihomo/component/dialer" + "github.com/metacubex/mihomo/component/loopback" "github.com/metacubex/mihomo/component/resolver" C "github.com/metacubex/mihomo/constant" ) type Direct struct { *Base - loopBack *loopBackDetector + loopBack *loopback.Detector } type DirectOption struct { @@ -24,8 +24,8 @@ type DirectOption struct { // DialContext implements C.ProxyAdapter func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { - if d.loopBack.CheckConn(metadata.SourceAddrPort()) { - return nil, fmt.Errorf("reject loopback connection to: %s", metadata.RemoteAddress()) + if err := d.loopBack.CheckConn(metadata); err != nil { + return nil, err } opts = append(opts, dialer.WithResolver(resolver.DefaultResolver)) c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...) @@ -38,8 +38,8 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ... // ListenPacketContext implements C.ProxyAdapter func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { - if d.loopBack.CheckPacketConn(metadata.SourceAddrPort()) { - return nil, fmt.Errorf("reject loopback connection to: %s", metadata.RemoteAddress()) + if err := d.loopBack.CheckPacketConn(metadata); err != nil { + return nil, err } // net.UDPConn.WriteTo only working with *net.UDPAddr, so we need a net.UDPAddr if !metadata.Resolved() { @@ -68,7 +68,7 @@ func NewDirectWithOption(option DirectOption) *Direct { rmark: option.RoutingMark, prefer: C.NewDNSPrefer(option.IPVersion), }, - loopBack: newLoopBackDetector(), + loopBack: loopback.NewDetector(), } } @@ -80,7 +80,7 @@ func NewDirect() *Direct { udp: true, prefer: C.DualStack, }, - loopBack: newLoopBackDetector(), + loopBack: loopback.NewDetector(), } } @@ -92,6 +92,6 @@ func NewCompatible() *Direct { udp: true, prefer: C.DualStack, }, - loopBack: newLoopBackDetector(), + loopBack: loopback.NewDetector(), } } diff --git a/adapter/outbound/direct_loopback_detect.go b/component/loopback/detector.go similarity index 58% rename from adapter/outbound/direct_loopback_detect.go rename to component/loopback/detector.go index 410d5a2f..b07270ed 100644 --- a/adapter/outbound/direct_loopback_detect.go +++ b/component/loopback/detector.go @@ -1,6 +1,8 @@ -package outbound +package loopback import ( + "errors" + "fmt" "net/netip" "github.com/metacubex/mihomo/common/callback" @@ -9,19 +11,21 @@ import ( "github.com/puzpuzpuz/xsync/v3" ) -type loopBackDetector struct { +var ErrReject = errors.New("reject loopback connection") + +type Detector struct { connMap *xsync.MapOf[netip.AddrPort, struct{}] packetConnMap *xsync.MapOf[netip.AddrPort, struct{}] } -func newLoopBackDetector() *loopBackDetector { - return &loopBackDetector{ +func NewDetector() *Detector { + return &Detector{ connMap: xsync.NewMapOf[netip.AddrPort, struct{}](), packetConnMap: xsync.NewMapOf[netip.AddrPort, struct{}](), } } -func (l *loopBackDetector) NewConn(conn C.Conn) C.Conn { +func (l *Detector) NewConn(conn C.Conn) C.Conn { metadata := C.Metadata{} if metadata.SetRemoteAddr(conn.LocalAddr()) != nil { return conn @@ -36,7 +40,7 @@ func (l *loopBackDetector) NewConn(conn C.Conn) C.Conn { }) } -func (l *loopBackDetector) NewPacketConn(conn C.PacketConn) C.PacketConn { +func (l *Detector) NewPacketConn(conn C.PacketConn) C.PacketConn { metadata := C.Metadata{} if metadata.SetRemoteAddr(conn.LocalAddr()) != nil { return conn @@ -51,18 +55,24 @@ func (l *loopBackDetector) NewPacketConn(conn C.PacketConn) C.PacketConn { }) } -func (l *loopBackDetector) CheckConn(connAddr netip.AddrPort) bool { +func (l *Detector) CheckConn(metadata *C.Metadata) error { + connAddr := metadata.SourceAddrPort() if !connAddr.IsValid() { - return false + return nil } - _, ok := l.connMap.Load(connAddr) - return ok + if _, ok := l.connMap.Load(connAddr); ok { + return fmt.Errorf("%w to: %s", ErrReject, metadata.RemoteAddress()) + } + return nil } -func (l *loopBackDetector) CheckPacketConn(connAddr netip.AddrPort) bool { +func (l *Detector) CheckPacketConn(metadata *C.Metadata) error { + connAddr := metadata.SourceAddrPort() if !connAddr.IsValid() { - return false + return nil } - _, ok := l.packetConnMap.Load(connAddr) - return ok + if _, ok := l.packetConnMap.Load(connAddr); ok { + return fmt.Errorf("%w to: %s", ErrReject, metadata.RemoteAddress()) + } + return nil } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index d5a226e9..608ab2c5 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -12,6 +12,7 @@ import ( "time" N "github.com/metacubex/mihomo/common/net" + "github.com/metacubex/mihomo/component/loopback" "github.com/metacubex/mihomo/component/nat" P "github.com/metacubex/mihomo/component/process" "github.com/metacubex/mihomo/component/resolver" @@ -694,6 +695,9 @@ func shouldStopRetry(err error) bool { if errors.Is(err, resolver.ErrIPv6Disabled) { return true } + if errors.Is(err, loopback.ErrReject) { + return true + } return false }