From 8b9813079b85e5020dbf5ddb8f321c647944c136 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Mon, 4 Mar 2024 22:12:08 +0800 Subject: [PATCH] chore: share RelayDnsPacket function code --- adapter/outbound/dns.go | 34 ++++---------- component/resolver/relay.go | 88 +++++++++++++++++++++++++++++++++++++ listener/sing_tun/dns.go | 88 +++---------------------------------- 3 files changed, 102 insertions(+), 108 deletions(-) create mode 100644 component/resolver/relay.go diff --git a/adapter/outbound/dns.go b/adapter/outbound/dns.go index 94819749..405392a1 100644 --- a/adapter/outbound/dns.go +++ b/adapter/outbound/dns.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net" - "net/netip" "time" "github.com/metacubex/mihomo/common/pool" @@ -12,8 +11,6 @@ import ( "github.com/metacubex/mihomo/component/resolver" C "github.com/metacubex/mihomo/constant" "github.com/metacubex/mihomo/log" - - D "github.com/miekg/dns" ) type Dns struct { @@ -79,12 +76,12 @@ func (d *dnsPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { } func (d *dnsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { - ctx, cancel := context.WithTimeout(d.ctx, time.Second*5) + ctx, cancel := context.WithTimeout(d.ctx, resolver.DefaultDnsRelayTimeout) defer cancel() - buf := pool.Get(2048) + buf := pool.Get(resolver.SafeDnsPacketSize) put := func() { _ = pool.Put(buf) } - buf, err = RelayDnsPacket(ctx, p, buf) + buf, err = resolver.RelayDnsPacket(ctx, p, buf) if err != nil { put() return 0, err @@ -110,7 +107,11 @@ func (d *dnsPacketConn) Close() error { } func (*dnsPacketConn) LocalAddr() net.Addr { - return net.UDPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:53")) + return &net.UDPAddr{ + IP: net.IPv4(127, 0, 0, 1), + Port: 53, + Zone: "", + } } func (*dnsPacketConn) SetDeadline(t time.Time) error { @@ -139,22 +140,3 @@ func NewDnsWithOption(option DnsOption) *Dns { }, } } - -// copied from listener/sing_mux/dns.go -func RelayDnsPacket(ctx context.Context, payload []byte, target []byte) ([]byte, error) { - msg := &D.Msg{} - if err := msg.Unpack(payload); err != nil { - return nil, err - } - - r, err := resolver.ServeMsg(ctx, msg) - if err != nil { - m := new(D.Msg) - m.SetRcode(msg, D.RcodeServerFailure) - return m.PackBuffer(target) - } - - r.SetRcode(msg, r.Rcode) - r.Compress = true - return r.PackBuffer(target) -} diff --git a/component/resolver/relay.go b/component/resolver/relay.go new file mode 100644 index 00000000..3bc54445 --- /dev/null +++ b/component/resolver/relay.go @@ -0,0 +1,88 @@ +package resolver + +import ( + "context" + "encoding/binary" + "io" + "net" + "time" + + "github.com/metacubex/mihomo/common/pool" + + D "github.com/miekg/dns" +) + +const DefaultDnsReadTimeout = time.Second * 10 +const DefaultDnsRelayTimeout = time.Second * 5 + +const SafeDnsPacketSize = 2 * 1024 // safe size which is 1232 from https://dnsflagday.net/2020/, so 2048 is enough + +func RelayDnsConn(ctx context.Context, conn net.Conn) error { + buff := pool.Get(pool.UDPBufferSize) + defer func() { + _ = pool.Put(buff) + _ = conn.Close() + }() + for { + if conn.SetReadDeadline(time.Now().Add(DefaultDnsReadTimeout)) != nil { + break + } + + length := uint16(0) + if err := binary.Read(conn, binary.BigEndian, &length); err != nil { + break + } + + if int(length) > len(buff) { + break + } + + n, err := io.ReadFull(conn, buff[:length]) + if err != nil { + break + } + + err = func() error { + ctx, cancel := context.WithTimeout(ctx, DefaultDnsRelayTimeout) + defer cancel() + inData := buff[:n] + msg, err := RelayDnsPacket(ctx, inData, buff) + if err != nil { + return err + } + + err = binary.Write(conn, binary.BigEndian, uint16(len(msg))) + if err != nil { + return err + } + + _, err = conn.Write(msg) + if err != nil { + return err + } + return nil + }() + if err != nil { + return err + } + } + return nil +} + +func RelayDnsPacket(ctx context.Context, payload []byte, target []byte) ([]byte, error) { + msg := &D.Msg{} + if err := msg.Unpack(payload); err != nil { + return nil, err + } + + r, err := ServeMsg(ctx, msg) + if err != nil { + m := new(D.Msg) + m.SetRcode(msg, D.RcodeServerFailure) + return m.PackBuffer(target) + } + + r.SetRcode(msg, r.Rcode) + r.Compress = true + return r.PackBuffer(target) +} diff --git a/listener/sing_tun/dns.go b/listener/sing_tun/dns.go index 056c9169..86237daa 100644 --- a/listener/sing_tun/dns.go +++ b/listener/sing_tun/dns.go @@ -2,29 +2,21 @@ package sing_tun import ( "context" - "encoding/binary" - "io" "net" "net/netip" "sync" "time" - "github.com/metacubex/mihomo/common/pool" "github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/listener/sing" "github.com/metacubex/mihomo/log" - D "github.com/miekg/dns" - "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/network" ) -const DefaultDnsReadTimeout = time.Second * 10 -const DefaultDnsRelayTimeout = time.Second * 5 - type ListenerHandler struct { *sing.ListenerHandler DnsAdds []netip.AddrPort @@ -45,61 +37,11 @@ func (h *ListenerHandler) ShouldHijackDns(targetAddr netip.AddrPort) bool { func (h *ListenerHandler) NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error { if h.ShouldHijackDns(metadata.Destination.AddrPort()) { log.Debugln("[DNS] hijack tcp:%s", metadata.Destination.String()) - buff := pool.Get(pool.UDPBufferSize) - defer func() { - _ = pool.Put(buff) - _ = conn.Close() - }() - for { - if conn.SetReadDeadline(time.Now().Add(DefaultDnsReadTimeout)) != nil { - break - } - - length := uint16(0) - if err := binary.Read(conn, binary.BigEndian, &length); err != nil { - break - } - - if int(length) > len(buff) { - break - } - - n, err := io.ReadFull(conn, buff[:length]) - if err != nil { - break - } - - err = func() error { - ctx, cancel := context.WithTimeout(ctx, DefaultDnsRelayTimeout) - defer cancel() - inData := buff[:n] - msg, err := RelayDnsPacket(ctx, inData, buff) - if err != nil { - return err - } - - err = binary.Write(conn, binary.BigEndian, uint16(len(msg))) - if err != nil { - return err - } - - _, err = conn.Write(msg) - if err != nil { - return err - } - return nil - }() - if err != nil { - return err - } - } - return nil + return resolver.RelayDnsConn(ctx, conn) } return h.ListenerHandler.NewConnection(ctx, conn, metadata) } -const SafeDnsPacketSize = 2 * 1024 // safe size which is 1232 from https://dnsflagday.net/2020/, so 2048 is enough - func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network.PacketConn, metadata M.Metadata) error { if h.ShouldHijackDns(metadata.Destination.AddrPort()) { log.Debugln("[DNS] hijack udp:%s from %s", metadata.Destination.String(), metadata.Source.String()) @@ -114,7 +56,7 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network. rwOptions := network.ReadWaitOptions{ FrontHeadroom: network.CalculateFrontHeadroom(conn), RearHeadroom: network.CalculateRearHeadroom(conn), - MTU: SafeDnsPacketSize, + MTU: resolver.SafeDnsPacketSize, } readWaiter, isReadWaiter := bufio.CreatePacketReadWaiter(conn) if isReadWaiter { @@ -126,7 +68,7 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network. dest M.Socksaddr err error ) - _ = conn.SetReadDeadline(time.Now().Add(DefaultDnsReadTimeout)) + _ = conn.SetReadDeadline(time.Now().Add(resolver.DefaultDnsReadTimeout)) readBuff = nil // clear last loop status, avoid repeat release if isReadWaiter { readBuff, dest, err = readWaiter.WaitReadPacket() @@ -147,15 +89,15 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network. return err } go func() { - ctx, cancel := context.WithTimeout(ctx, DefaultDnsRelayTimeout) + ctx, cancel := context.WithTimeout(ctx, resolver.DefaultDnsRelayTimeout) defer cancel() inData := readBuff.Bytes() writeBuff := readBuff writeBuff.Resize(writeBuff.Start(), 0) - if len(writeBuff.FreeBytes()) < SafeDnsPacketSize { // only create a new buffer when space don't enough + if len(writeBuff.FreeBytes()) < resolver.SafeDnsPacketSize { // only create a new buffer when space don't enough writeBuff = rwOptions.NewPacketBuffer() } - msg, err := RelayDnsPacket(ctx, inData, writeBuff.FreeBytes()) + msg, err := resolver.RelayDnsPacket(ctx, inData, writeBuff.FreeBytes()) if writeBuff != readBuff { readBuff.Release() } @@ -182,21 +124,3 @@ func (h *ListenerHandler) NewPacketConnection(ctx context.Context, conn network. } return h.ListenerHandler.NewPacketConnection(ctx, conn, metadata) } - -func RelayDnsPacket(ctx context.Context, payload []byte, target []byte) ([]byte, error) { - msg := &D.Msg{} - if err := msg.Unpack(payload); err != nil { - return nil, err - } - - r, err := resolver.ServeMsg(ctx, msg) - if err != nil { - m := new(D.Msg) - m.SetRcode(msg, D.RcodeServerFailure) - return m.PackBuffer(target) - } - - r.SetRcode(msg, r.Rcode) - r.Compress = true - return r.PackBuffer(target) -}