Skip to content

Commit

Permalink
nim: switch from go-dns-resolver to miekg/dns
Browse files Browse the repository at this point in the history
this allows us to set the source interface when doing
DNS requests

go-dns-resolver itself is using miekg/dns under the hood,
so now nim is using it directly instead

Signed-off-by: Christoph Ostarek <christoph@zededa.com>
  • Loading branch information
christoph-zededa authored and eriknordmark committed Mar 29, 2023
1 parent 74c1c6f commit 8573372
Show file tree
Hide file tree
Showing 3 changed files with 359 additions and 45 deletions.
66 changes: 21 additions & 45 deletions pkg/pillar/cmd/nim/controllerdns.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ import (
"errors"
"fmt"
"io/fs"
"net"
"os"
"time"

dns "github.com/Focinfi/go-dns-resolver"
"github.com/lf-edge/eve/pkg/pillar/devicenetwork"
"github.com/lf-edge/eve/pkg/pillar/types"
)

Expand Down Expand Up @@ -72,6 +71,18 @@ func (n *nim) queryControllerDNS() {
}
}

func (n *nim) resolveWithPorts(domain string) []devicenetwork.DNSResponse {
dnsResponse, errs := devicenetwork.ResolveWithPortsLambda(
domain,
n.dpcManager.GetDNS(),
devicenetwork.ResolveWithSrcIP,
)
if len(errs) > 0 {
n.Log.Warnf("resolveWithPortsLambda failed: %+v", errs)
}
return dnsResponse
}

// periodical cache the controller DNS resolution into /etc/hosts file
// it returns the cached ip string, and TTL setting from the server
func (n *nim) controllerDNSCache(
Expand All @@ -89,61 +100,26 @@ func (n *nim) controllerDNSCache(
return ipAddrCached, ttlCached
}

nameServers := n.readNameservers()

err := os.Remove(tmpHostFileName)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
n.Log.Warnf("%s exists but removing failed: %+v", tmpHostFileName, err)
}

var newhosts []byte
var gotipentry bool
var lookupIPaddr string
var ttlSec int
var lookupIPaddr string

domains := []string{string(controllerServer)}
dtypes := []dns.QueryType{dns.TypeA}
for _, nameServer := range nameServers {
resolver := dns.NewResolver(nameServer)
resolver.Targets(domains...).Types(dtypes...)

res := resolver.Lookup()
for target := range res.ResMap {
for _, r := range res.ResMap[target] {
dIP := net.ParseIP(r.Content)
if dIP == nil {
continue
}
lookupIPaddr = dIP.String()
ttlSec = getTTL(r.Ttl)
if ipaddrCached == lookupIPaddr {
n.Log.Tracef("same IP address %s, return", lookupIPaddr)
return ipaddrCached, ttlSec
}
serverEntry := fmt.Sprintf("%s %s\n", lookupIPaddr, controllerServer)
newhosts = append(etchosts, []byte(serverEntry)...)
gotipentry = true
// a rare event for dns address change, log it
n.Log.Noticef("dnsServer %s, ttl %d, entry add to /etc/hosts: %s", nameServer, ttlSec, serverEntry)
break
}
if gotipentry {
break
}
}
if gotipentry {
break
}
}
dnsResponses := n.resolveWithPorts(string(controllerServer))
ttlSec = int(dnsResponses[0].TTL)
lookupIPaddr = dnsResponses[0].IP.String()
serverEntry := fmt.Sprintf("%s %s\n", lookupIPaddr, controllerServer)
newhosts = append(etchosts, []byte(serverEntry)...)

if ipaddrCached == lookupIPaddr {
return ipaddrCached, minTTLSec
}
if !gotipentry { // put original /etc/hosts file back
if len(dnsResponses) == 0 {
newhosts = append(newhosts, etchosts...)
}

if n.writeTmpHostsFile(newhosts) && gotipentry {
if len(dnsResponses) > 0 && n.writeHostsFile(newhosts) {
n.Log.Tracef("append controller IP %s to /etc/hosts", lookupIPaddr)
ipaddrCached = lookupIPaddr
} else {
Expand Down
137 changes: 137 additions & 0 deletions pkg/pillar/devicenetwork/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,33 @@ package devicenetwork

import (
"fmt"
"net"
"os"
"path/filepath"
"strings"
"sync"
"time"

"github.com/lf-edge/eve/pkg/pillar/types"
"github.com/miekg/dns"
)

// ResolveConfDirs : directories where resolv.conf for an interface could be found.
var ResolveConfDirs = []string{"/run/dhcpcd/resolv.conf", "/run/wwan/resolv.conf"}

const (
// DNSMaxParallelRequests is the maximum amount of parallel DNS requests
DNSMaxParallelRequests = 5
maxTTLSec int = 3600
dnsTimeout = 30 * time.Second
)

// DNSResponse represents a response from a DNS server (A Record)
type DNSResponse struct {
IP net.IP
TTL uint32
}

// IfnameToResolvConf : Look for a file created by dhcpcd
func IfnameToResolvConf(ifname string) string {
for _, d := range ResolveConfDirs {
Expand All @@ -39,3 +58,121 @@ func ResolvConfToIfname(resolvConf string) string {
}
return ""
}

// ResolveWithSrcIP resolves a domain with a given dns server and source Ip
func ResolveWithSrcIP(domain string, dnsServerIP net.IP, srcIP net.IP) ([]DNSResponse, error) {
var response []DNSResponse
sourceUDPAddr := net.UDPAddr{IP: srcIP}
dialer := net.Dialer{LocalAddr: &sourceUDPAddr}
dnsClient := dns.Client{Dialer: &dialer}
msg := dns.Msg{}
if domain[len(domain)-1] != '.' {
domain = domain + "."
}
msg.SetQuestion(domain, dns.TypeA)
dnsClient.Timeout = time.Duration(dnsTimeout)
reply, _, err := dnsClient.Exchange(&msg, net.JoinHostPort(dnsServerIP.String(), "53"))
if err != nil {
return response, fmt.Errorf("dns exchange failed: %v", err)
}
for _, answer := range reply.Answer {
if aRecord, ok := answer.(*dns.A); ok {
response = append(response, DNSResponse{
IP: aRecord.A,
TTL: aRecord.Header().Ttl,
})
}
}

return response, nil
}

// ResolveWithPortsLambda resolves a domain by using source IPs and dns servers from DeviceNetworkStatus
// As a resolver func ResolveWithSrcIP can be used
func ResolveWithPortsLambda(domain string,
dns types.DeviceNetworkStatus,
resolve func(string, net.IP, net.IP) ([]DNSResponse, error)) ([]DNSResponse, []error) {

quit := make(chan struct{})
work := make(chan struct{}, DNSMaxParallelRequests)
resolvedIPsChan := make(chan []DNSResponse)
countDNSRequests := 0
var errs []error
var errsMutex sync.Mutex
var wg sync.WaitGroup

for _, port := range dns.Ports {
if port.Cost > 0 {
continue
}

var srcIPs []net.IP
for _, addrInfo := range port.AddrInfoList {
srcIPs = append(srcIPs, addrInfo.Addr)
}

for _, dnsIP := range port.DNSServers {
for _, srcIP := range srcIPs {
wg.Add(1)
dnsIPCopy := make(net.IP, len(dnsIP))
copy(dnsIPCopy, dnsIP)
srcIPCopy := make(net.IP, len(srcIP))
copy(srcIPCopy, srcIP)
countDNSRequests++
go func(dnsIP, srcIP net.IP) {
select {
case work <- struct{}{}:
// if writable, means less than dnsMaxParallelRequests goroutines are currently running
}
select {
case <-quit:
// will return in case the quit chan has been closed,
// meaning another dns server already resolved the IP
return
default:
// do not wait for receiving a quit
}
response, err := resolve(domain, dnsIP, srcIP)
if err != nil {
errsMutex.Lock()
defer errsMutex.Unlock()
errs = append(errs, err)
}
if response != nil {
resolvedIPsChan <- response
}
<-work
wg.Done()
}(dnsIPCopy, srcIPCopy)
}
}
}

wgChan := make(chan struct{})
go func() {
wg.Wait()
close(wgChan)
}()

select {
case <-wgChan:
var responses []DNSResponse
if countDNSRequests == 0 {
// fallback in case no resolver is configured
ips, err := net.LookupIP(domain)
if err != nil {
return nil, append(errs, fmt.Errorf("fallback resolver failed: %+v", err))
}
for _, ip := range ips {
responses = append(responses, DNSResponse{
IP: ip,
TTL: uint32(maxTTLSec),
})
}
}
return responses, nil
case ip := <-resolvedIPsChan:
close(quit)
return ip, errs
}
}
Loading

0 comments on commit 8573372

Please sign in to comment.