Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize: cache NXDomain and reject with 0.0.0.0/:: #63

Merged
merged 3 commits into from
Apr 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions common/netutils/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,21 @@ func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host
func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ dnsmessage.Type, tcp bool) (ans []dnsmessage.Resource, err error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
fqdn := host
if !strings.HasSuffix(fqdn, ".") {
fqdn += "."
}
switch typ {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
if addr, err := netip.ParseAddr(host); err == nil {
if (addr.Is4() || addr.Is4In6()) && typ == dnsmessage.TypeA {
return []dnsmessage.Resource{
{
Header: dnsmessage.ResourceHeader{
Type: typ,
Name: dnsmessage.MustNewName(fqdn),
Class: dnsmessage.ClassINET,
TTL: 0,
Type: typ,
},
Body: &dnsmessage.AResource{A: addr.As4()},
},
Expand All @@ -158,7 +165,10 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
return []dnsmessage.Resource{
{
Header: dnsmessage.ResourceHeader{
Type: typ,
Name: dnsmessage.MustNewName(fqdn),
Class: dnsmessage.ClassINET,
TTL: 0,
Type: typ,
},
Body: &dnsmessage.AAAAResource{AAAA: addr.As16()},
},
Expand All @@ -181,10 +191,6 @@ func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host st
if err = builder.StartQuestions(); err != nil {
return nil, err
}
fqdn := host
if !strings.HasSuffix(fqdn, ".") {
fqdn += "."
}
if err = builder.Question(dnsmessage.Question{
Name: dnsmessage.MustNewName(fqdn),
Type: typ,
Expand Down
9 changes: 7 additions & 2 deletions control/control_plane_core.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,12 +523,17 @@ func (c *controlPlaneCore) BatchUpdateDomainRouting(cache *DnsCache) error {
// Parse ips from DNS resp answers.
var ips []netip.Addr
for _, ans := range cache.Answers {
var ip netip.Addr
switch ans.Header.Type {
case dnsmessage.TypeA:
ips = append(ips, netip.AddrFrom4(ans.Body.(*dnsmessage.AResource).A))
ip = netip.AddrFrom4(ans.Body.(*dnsmessage.AResource).A)
case dnsmessage.TypeAAAA:
ips = append(ips, netip.AddrFrom16(ans.Body.(*dnsmessage.AAAAResource).AAAA))
ip = netip.AddrFrom16(ans.Body.(*dnsmessage.AAAAResource).AAAA)
}
if ip.IsUnspecified() {
continue
}
ips = append(ips, ip)
}
if len(ips) == 0 {
return nil
Expand Down
10 changes: 10 additions & 0 deletions control/dns_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,13 @@ func (c *DnsCache) IncludeIp(ip netip.Addr) bool {
}
return false
}

func (c *DnsCache) IncludeAnyIp() bool {
for _, ans := range c.Answers {
switch ans.Body.(type) {
case *dnsmessage.AResource, *dnsmessage.AAAAResource:
return true
}
}
return false
}
94 changes: 74 additions & 20 deletions control/dns_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ import (

const (
MaxDnsLookupDepth = 3
minFirefoxCacheTimeout = 120 * time.Second
minFirefoxCacheTtl = 120
minFirefoxCacheTimeout = minFirefoxCacheTtl * time.Second
)

type IpVersionPrefer int
Expand All @@ -49,6 +50,11 @@ var (
UnsupportedQuestionTypeError = fmt.Errorf("unsupported question type")
)

var (
UnspecifiedAddressA = netip.MustParseAddr("0.0.0.0")
UnspecifiedAddressAAAA = netip.MustParseAddr("::")
)

type DnsControllerOption struct {
Log *logrus.Logger
CacheAccessCallback func(cache *DnsCache) (err error)
Expand Down Expand Up @@ -125,14 +131,12 @@ func (c *DnsController) RemoveDnsRespCache(qname string, qtype dnsmessage.Type)
c.dnsCacheMu.Unlock()
}
func (c *DnsController) LookupDnsRespCache(qname string, qtype dnsmessage.Type) (cache *DnsCache) {
now := time.Now()

c.dnsCacheMu.Lock()
cache, ok := c.dnsCache[c.cacheKey(qname, qtype)]
c.dnsCacheMu.Unlock()
// We should make sure the remaining TTL is greater than 120s (minFirefoxCacheTimeout), or
// return nil and request a new lookup to refresh the cache.
if ok && cache.Deadline.After(now.Add(minFirefoxCacheTimeout)) {
if ok {
return cache
}
return nil
Expand Down Expand Up @@ -187,35 +191,52 @@ func (c *DnsController) DnsRespHandler(data []byte, validateRushAns bool) (newMs
return &msg, nil
}

// Get TTL.
var ttl uint32
for i := range msg.Answers {
if ttl == 0 {
ttl = msg.Answers[i].Header.TTL
break
}
}
if ttl == 0 {
// It seems no answers (NXDomain).
ttl = minFirefoxCacheTtl
}

// Check req type.
switch q.Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
default:
// Update DnsCache.
if err = c.updateDnsCache(&msg, ttl, &q); err != nil {
return nil, err
}
return &msg, nil
}

// Set ttl.
var ttl uint32
for i := range msg.Answers {
if ttl == 0 {
ttl = msg.Answers[i].Header.TTL
}
// Set TTL = zero. This requests applications must resend every request.
// However, it may be not defined in the standard.
msg.Answers[i].Header.TTL = 0
}

// Check if there is any A/AAAA record.
var hasIpRecord bool
// Check if request A/AAAA record.
var reqIpRecord bool
loop:
for i := range msg.Answers {
switch msg.Answers[i].Header.Type {
for i := range msg.Questions {
switch msg.Questions[i].Type {
case dnsmessage.TypeA, dnsmessage.TypeAAAA:
hasIpRecord = true
reqIpRecord = true
break loop
}
}
if !hasIpRecord {
if !reqIpRecord {
// Update DnsCache.
if err = c.updateDnsCache(&msg, ttl, &q); err != nil {
return nil, err
}
return &msg, nil
}

Expand All @@ -236,6 +257,15 @@ loop:
}
}

// Update DnsCache.
if err = c.updateDnsCache(&msg, ttl, &q); err != nil {
return nil, err
}
// Pack to get newData.
return &msg, nil
}

func (c *DnsController) updateDnsCache(msg *dnsmessage.Message, ttl uint32, q *dnsmessage.Question) error {
// Update DnsCache.
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
Expand All @@ -252,11 +282,10 @@ loop:
}
cacheTimeout += 5 * time.Second // DNS lookup timeout.

if err = c.UpdateDnsCache(q.Name.String(), q.Type.String(), msg.Answers, time.Now().Add(cacheTimeout)); err != nil {
return nil, err
if err := c.UpdateDnsCache(q.Name.String(), q.Type.String(), msg.Answers, time.Now().Add(cacheTimeout)); err != nil {
return err
}
// Pack to get newData.
return &msg, nil
return nil
}

func (c *DnsController) UpdateDnsCache(host string, dnsTyp string, answers []dnsmessage.Resource, deadline time.Time) (err error) {
Expand Down Expand Up @@ -407,7 +436,7 @@ func (c *DnsController) Handle_(dnsMessage *dnsmessage.Message, req *udpRequest)
}
// resp is valid.
cache2 := c.LookupDnsRespCache(qname, qtype2)
if c.qtypePrefer == qtype || cache2 == nil {
if c.qtypePrefer == qtype || cache2 == nil || !cache2.IncludeAnyIp() {
return sendPkt(resp, req.realDst, req.realSrc, req.src, req.lConn, req.lanWanFlag)
} else {
return c.sendReject_(dnsMessage, req)
Expand Down Expand Up @@ -490,14 +519,39 @@ func (c *DnsController) handle_(
// sendReject_ send empty answer.
func (c *DnsController) sendReject_(dnsMessage *dnsmessage.Message, req *udpRequest) (err error) {
dnsMessage.Answers = nil
if len(dnsMessage.Questions) > 0 {
q := dnsMessage.Questions[0]
switch typ := q.Type; typ {
case dnsmessage.TypeA:
dnsMessage.Answers = []dnsmessage.Resource{{
Header: dnsmessage.ResourceHeader{
Name: q.Name,
Type: typ,
Class: dnsmessage.ClassINET,
TTL: 0,
},
Body: &dnsmessage.AResource{A: UnspecifiedAddressA.As4()},
}}
case dnsmessage.TypeAAAA:
dnsMessage.Answers = []dnsmessage.Resource{{
Header: dnsmessage.ResourceHeader{
Name: q.Name,
Type: typ,
Class: dnsmessage.ClassINET,
TTL: 0,
},
Body: &dnsmessage.AAAAResource{AAAA: UnspecifiedAddressAAAA.As16()},
}}
}
}
dnsMessage.RCode = dnsmessage.RCodeSuccess
dnsMessage.Response = true
dnsMessage.RecursionAvailable = true
dnsMessage.Truncated = false
if c.log.IsLevelEnabled(logrus.TraceLevel) {
c.log.WithFields(logrus.Fields{
"question": dnsMessage.Questions,
}).Traceln("Reject with empty answer")
}).Traceln("Reject")
}
data, err := dnsMessage.Pack()
if err != nil {
Expand Down