sing-box/common/sniff/dns.go

61 lines
1.5 KiB
Go
Raw Normal View History

2022-07-06 12:39:44 +08:00
package sniff
import (
"context"
"encoding/binary"
"io"
"os"
"time"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/task"
2022-07-06 15:01:09 +08:00
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
2022-07-06 12:39:44 +08:00
"golang.org/x/net/dns/dnsmessage"
)
func StreamDomainNameQuery(readCtx context.Context, reader io.Reader) (*adapter.InboundContext, error) {
var length uint16
err := binary.Read(reader, binary.BigEndian, &length)
if err != nil {
return nil, err
}
if length > 512 {
return nil, os.ErrInvalid
}
_buffer := buf.StackNewSize(int(length))
defer common.KeepAlive(_buffer)
buffer := common.Dup(_buffer)
defer buffer.Release()
readCtx, cancel := context.WithTimeout(readCtx, time.Millisecond*100)
err = task.Run(readCtx, func() error {
return common.Error(buffer.ReadFullFrom(reader, buffer.FreeLen()))
})
cancel()
if err != nil {
return nil, err
}
return DomainNameQuery(readCtx, buffer.Bytes())
}
func DomainNameQuery(ctx context.Context, packet []byte) (*adapter.InboundContext, error) {
var parser dnsmessage.Parser
_, err := parser.Start(packet)
if err != nil {
return nil, err
}
question, err := parser.Question()
if err != nil {
return nil, os.ErrInvalid
}
domain := question.Name.String()
if question.Class == dnsmessage.ClassINET && (question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA) && IsDomainName(domain) {
return &adapter.InboundContext{Protocol: C.ProtocolDNS, Domain: domain}, nil
}
return nil, os.ErrInvalid
}