Add disableCache/disableExpire option for dns client

This commit is contained in:
世界 2022-07-06 23:39:17 +08:00
parent 8a761d7e3b
commit ecac383477
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
5 changed files with 138 additions and 105 deletions

View File

@ -13,6 +13,7 @@ import (
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/option"
"golang.org/x/net/dns/dnsmessage"
)
@ -27,12 +28,21 @@ var (
var _ adapter.DNSClient = (*Client)(nil)
type Client struct {
cache *cache.LruCache[dnsmessage.Question, dnsmessage.Message]
cache *cache.LruCache[dnsmessage.Question, *dnsmessage.Message]
disableCache bool
disableExpire bool
}
func NewClient() *Client {
return &Client{
cache: cache.New[dnsmessage.Question, dnsmessage.Message](),
func NewClient(options option.DNSClientOptions) *Client {
if options.DisableCache {
return &Client{
disableCache: true,
}
} else {
return &Client{
cache: cache.New[dnsmessage.Question, *dnsmessage.Message](),
disableExpire: options.DisableExpire,
}
}
}
@ -41,10 +51,12 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
return nil, E.New("empty query")
}
question := message.Questions[0]
cachedAnswer, cached := c.cache.Load(question)
if cached {
cachedAnswer.ID = message.ID
return &cachedAnswer, nil
if !c.disableCache {
cachedAnswer, cached := c.cache.Load(question)
if cached {
cachedAnswer.ID = message.ID
return cachedAnswer, nil
}
}
if !transport.Raw() {
if question.Type == dnsmessage.TypeA || question.Type == dnsmessage.TypeAAAA {
@ -56,7 +68,9 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
if err != nil {
return nil, err
}
c.cache.StoreWithExpire(question, *response, calculateExpire(message))
if !c.disableCache {
c.storeCache(question, response)
}
return message, err
}
@ -93,37 +107,39 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
}
return sortAddresses(response4, response6, strategy), nil
}
if strategy == C.DomainStrategyUseIPv4 {
response, err := c.questionCache(dnsmessage.Question{
Name: dnsName,
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
})
if err != ErrNotCached {
return response, err
}
} else if strategy == C.DomainStrategyUseIPv6 {
response, err := c.questionCache(dnsmessage.Question{
Name: dnsName,
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
})
if err != ErrNotCached {
return response, err
}
} else {
response4, _ := c.questionCache(dnsmessage.Question{
Name: dnsName,
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
})
response6, _ := c.questionCache(dnsmessage.Question{
Name: dnsName,
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
})
if len(response4) > 0 || len(response6) > 0 {
return sortAddresses(response4, response6, strategy), nil
if !c.disableCache {
if strategy == C.DomainStrategyUseIPv4 {
response, err := c.questionCache(dnsmessage.Question{
Name: dnsName,
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
})
if err != ErrNotCached {
return response, err
}
} else if strategy == C.DomainStrategyUseIPv6 {
response, err := c.questionCache(dnsmessage.Question{
Name: dnsName,
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
})
if err != ErrNotCached {
return response, err
}
} else {
response4, _ := c.questionCache(dnsmessage.Question{
Name: dnsName,
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
})
response6, _ := c.questionCache(dnsmessage.Question{
Name: dnsName,
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
})
if len(response4) > 0 || len(response6) > 0 {
return sortAddresses(response4, response6, strategy), nil
}
}
}
var rCode dnsmessage.RCode
@ -135,70 +151,74 @@ func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, dom
} else {
rCode = dnsmessage.RCode(rCodeError)
}
if c.disableCache {
return nil, err
}
}
header := dnsmessage.Header{
Response: true,
Authoritative: true,
RCode: rCode,
}
expire := time.Now().Add(time.Second * time.Duration(DefaultTTL))
if strategy != C.DomainStrategyUseIPv6 {
question4 := dnsmessage.Question{
Name: dnsName,
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
}
response4 := common.Filter(response, func(addr netip.Addr) bool {
return addr.Is4() || addr.Is4In6()
})
message4 := dnsmessage.Message{
Header: header,
Questions: []dnsmessage.Question{question4},
}
if len(response4) > 0 {
for _, address := range response4 {
message4.Answers = append(message4.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: question4.Name,
Class: question4.Class,
TTL: DefaultTTL,
},
Body: &dnsmessage.AResource{
A: address.As4(),
},
})
if !c.disableCache {
if strategy != C.DomainStrategyUseIPv6 {
question4 := dnsmessage.Question{
Name: dnsName,
Type: dnsmessage.TypeA,
Class: dnsmessage.ClassINET,
}
}
c.cache.StoreWithExpire(question4, message4, expire)
}
if strategy != C.DomainStrategyUseIPv4 {
question6 := dnsmessage.Question{
Name: dnsName,
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
}
response6 := common.Filter(response, func(addr netip.Addr) bool {
return addr.Is6() && !addr.Is4In6()
})
message6 := dnsmessage.Message{
Header: header,
Questions: []dnsmessage.Question{question6},
}
if len(response6) > 0 {
for _, address := range response6 {
message6.Answers = append(message6.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: question6.Name,
Class: question6.Class,
TTL: DefaultTTL,
},
Body: &dnsmessage.AAAAResource{
AAAA: address.As16(),
},
})
response4 := common.Filter(response, func(addr netip.Addr) bool {
return addr.Is4() || addr.Is4In6()
})
message4 := &dnsmessage.Message{
Header: header,
Questions: []dnsmessage.Question{question4},
}
if len(response4) > 0 {
for _, address := range response4 {
message4.Answers = append(message4.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: question4.Name,
Class: question4.Class,
TTL: DefaultTTL,
},
Body: &dnsmessage.AResource{
A: address.As4(),
},
})
}
}
c.storeCache(question4, message4)
}
if strategy != C.DomainStrategyUseIPv4 {
question6 := dnsmessage.Question{
Name: dnsName,
Type: dnsmessage.TypeAAAA,
Class: dnsmessage.ClassINET,
}
response6 := common.Filter(response, func(addr netip.Addr) bool {
return addr.Is6() && !addr.Is4In6()
})
message6 := &dnsmessage.Message{
Header: header,
Questions: []dnsmessage.Question{question6},
}
if len(response6) > 0 {
for _, address := range response6 {
message6.Answers = append(message6.Answers, dnsmessage.Resource{
Header: dnsmessage.ResourceHeader{
Name: question6.Name,
Class: question6.Class,
TTL: DefaultTTL,
},
Body: &dnsmessage.AAAAResource{
AAAA: address.As16(),
},
})
}
}
c.storeCache(question6, message6)
}
c.cache.StoreWithExpire(question6, message6, expire)
}
return response, err
}
@ -211,14 +231,19 @@ func sortAddresses(response4 []netip.Addr, response6 []netip.Addr, strategy C.Do
}
}
func calculateExpire(message *dnsmessage.Message) time.Time {
func (c *Client) storeCache(question dnsmessage.Question, message *dnsmessage.Message) {
if c.disableExpire {
c.cache.Store(question, message)
return
}
timeToLive := DefaultTTL
for _, answer := range message.Answers {
if int(answer.Header.TTL) < timeToLive {
timeToLive = int(answer.Header.TTL)
}
}
return time.Now().Add(time.Second * time.Duration(timeToLive))
expire := time.Now().Add(time.Second * time.Duration(timeToLive))
c.cache.StoreWithExpire(question, message, expire)
}
func (c *Client) exchangeToLookup(ctx context.Context, transport adapter.DNSTransport, message *dnsmessage.Message, question dnsmessage.Question) (*dnsmessage.Message, error) {
@ -275,9 +300,11 @@ func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTran
Type: qType,
Class: dnsmessage.ClassINET,
}
cachedAddresses, err := c.questionCache(question)
if err != ErrNotCached {
return cachedAddresses, err
if !c.disableCache {
cachedAddresses, err := c.questionCache(question)
if err != ErrNotCached {
return cachedAddresses, err
}
}
message := dnsmessage.Message{
Header: dnsmessage.Header{
@ -298,7 +325,7 @@ func (c *Client) questionCache(question dnsmessage.Question) ([]netip.Addr, erro
if !cached {
return nil, ErrNotCached
}
return messageToAddresses(&response)
return messageToAddresses(response)
}
func messageToAddresses(response *dnsmessage.Message) ([]netip.Addr, error) {

View File

@ -91,7 +91,7 @@ func (t *TCPTransport) newConnection(conn *dnsConnection) {
cancel()
conn.err = err
if err != nil {
t.logger.Warn("connection closed: ", err)
t.logger.Debug("connection closed: ", err)
}
}

View File

@ -99,7 +99,7 @@ func (t *TLSTransport) newConnection(conn *dnsConnection) {
cancel()
conn.err = err
if err != nil {
t.logger.Warn("connection closed: ", err)
t.logger.Debug("connection closed: ", err)
}
}

View File

@ -87,7 +87,7 @@ func (t *UDPTransport) newConnection(conn *dnsConnection) {
cancel()
conn.err = err
if err != nil {
t.logger.Warn("connection closed: ", err)
t.logger.Debug("connection closed: ", err)
}
}

View File

@ -2,11 +2,17 @@ package option
type DNSOptions struct {
Servers []DNSServerOptions `json:"servers,omitempty"`
DNSClientOptions
}
type DNSClientOptions struct {
DisableCache bool `json:"disable_cache,omitempty"`
DisableExpire bool `json:"disable_expire,omitempty"`
}
type DNSServerOptions struct {
Tag string `json:"tag,omitempty"`
Address string `json:"address"`
Detour string `json:"detour,omitempty"`
AddressResolver string `json:"address_resolver,omitempty"`
DialerOptions
}