Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix using miekg dns #3136

Merged
merged 7 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 112 additions & 78 deletions pkg/pillar/cmd/nim/controllerdns.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ package nim

import (
"bytes"
"errors"
"fmt"
"net"
"io/fs"
"os"
"time"

dns "github.com/Focinfi/go-dns-resolver"
"github.com/lf-edge/eve/pkg/pillar/devicenetwork"
"github.com/lf-edge/eve/pkg/pillar/types"
)

Expand Down Expand Up @@ -70,28 +71,104 @@ func (n *nim) queryControllerDNS() {
}
}

// periodical cache the controller DNS resolution into /etc/hosts file
// it returns the cached ip string, and TTL setting from the server
func (n *nim) controllerDNSCache(etchosts, controllerServer []byte, ipaddrCached string) (string, int) {
if len(etchosts) == 0 || len(controllerServer) == 0 {
return ipaddrCached, maxTTLSec
func (n *nim) resolveWithPorts(domain string) []devicenetwork.DNSResponse {
dnsResponse, errs := devicenetwork.ResolveWithPortsLambda(
domain,
n.dpcManager.GetDNS(),
devicenetwork.ResolveWithSrcIP,
)
if len(errs) > 0 {
n.Log.Warnf("resolveWithPortsLambda failed: %+v", errs)
}
return dnsResponse
}

// periodical cache the controller DNS resolution into /etc/hosts file
// it returns the cached ip string, and TTL setting from the server
func (n *nim) controllerDNSCache(
etchosts, controllerServer []byte,
ipaddrCached string,
) (string, int) {
// Check to see if the server domain is already in the /etc/hosts as in eden,
// then skip this DNS queries
if ipaddrCached == "" {
hostsEntries := bytes.Split(etchosts, []byte("\n"))
for _, entry := range hostsEntries {
fields := bytes.Fields(entry)
if len(fields) == 2 {
if bytes.Compare(fields[1], controllerServer) == 0 {
n.Log.Tracef("server entry %s already in /etc/hosts, skip", controllerServer)
return ipaddrCached, maxTTLSec
}
}
isCached, ipAddrCached, ttlCached := n.checkCachedEntry(
etchosts,
controllerServer,
ipaddrCached,
)
if isCached {
return ipAddrCached, ttlCached
}

err := os.Remove(tmpHostFileName)
if err != nil && !errors.Is(err, fs.ErrNotExist) {
n.Log.Warnf("%s exists but removing failed: %+v", tmpHostFileName, err)
}

dnsResponses := n.resolveWithPorts(string(controllerServer))
for _, dnsResponse := range dnsResponses {
if dnsResponse.IP.String() == ipAddrCached {
return ipAddrCached, getTTL(time.Duration(dnsResponse.TTL))
}
}

lookupIPaddr := n.writeHostsFile(dnsResponses, etchosts, controllerServer)
if lookupIPaddr != "" {
n.Log.Tracef("append controller IP %s to /etc/hosts", lookupIPaddr)
}

var ttlSec int

if len(dnsResponses) > 0 {
ipaddrCached = dnsResponses[0].IP.String()
ttlSec = getTTL(time.Duration(dnsResponses[0].TTL))
return ipaddrCached, ttlSec
} else {
return "", ttlSec
}
}

func (n *nim) writeHostsFile(
dnsResponses []devicenetwork.DNSResponse,
etchosts, controllerServer []byte,
) string {
return n.writeHostsFileToDestination(dnsResponses, etchosts, controllerServer, etcHostFileName)
}

func (n *nim) writeHostsFileToDestination(
dnsResponses []devicenetwork.DNSResponse,
etchosts, controllerServer []byte,
destination string,
) string {
var newhosts []byte

var lookupIPaddr string

if len(dnsResponses) == 0 {
newhosts = append(newhosts, etchosts...)
} else {
newhosts = append([]byte{}, etchosts...)
for _, dnsResponse := range dnsResponses {
lookupIPaddr = dnsResponse.IP.String()
serverEntry := fmt.Sprintf("%s %s\n", lookupIPaddr, controllerServer)
newhosts = append(newhosts, []byte(serverEntry)...)
}
}

err := os.WriteFile(tmpHostFileName, newhosts, 0644)
if err != nil {
n.Log.Errorf("can not write /tmp/etchosts file %v", err)
return ""
}
if err := os.Rename(tmpHostFileName, destination); err != nil {
n.Log.Errorf("can not rename %s file %v", destination, err)
return ""
}

return lookupIPaddr
}

func (*nim) readNameservers() []string {
var nameServers []string
dnsServer, _ := os.ReadFile(resolvFileName)
dnsRes := bytes.Split(dnsServer, []byte("\n"))
Expand All @@ -104,74 +181,31 @@ func (n *nim) controllerDNSCache(etchosts, controllerServer []byte, ipaddrCached
if len(nameServers) == 0 {
nameServers = append(nameServers, "8.8.8.8")
}
return nameServers
}

if _, err := os.Stat(tmpHostFileName); err == nil {
_ = os.Remove(tmpHostFileName)
func (n *nim) checkCachedEntry(
etchosts []byte,
controllerServer []byte,
ipaddrCached string,
) (bool, string, int) {
if len(etchosts) == 0 || len(controllerServer) == 0 {
return true, ipaddrCached, maxTTLSec
}

var newhosts []byte
var gotipentry bool
var lookupIPaddr string
var ttlSec int

domains := []string{string(controllerServer)}
dtypes := []dns.QueryType{dns.TypeA}
for _, nameServer := range nameServers {
resolver := dns.NewResolver(nameServer)
resolver.Targets(domains...).Types(dtypes...)

res := resolver.Lookup()
for target := range res.ResMap {
for _, r := range res.ResMap[target] {
dIP := net.ParseIP(r.Content)
if dIP == nil {
continue
}
lookupIPaddr = dIP.String()
ttlSec = getTTL(r.Ttl)
if ipaddrCached == lookupIPaddr {
n.Log.Tracef("same IP address %s, return", lookupIPaddr)
return ipaddrCached, ttlSec
if ipaddrCached == "" {
hostsEntries := bytes.Split(etchosts, []byte("\n"))
for _, entry := range hostsEntries {
fields := bytes.Fields(entry)
if len(fields) == 2 {
if bytes.Compare(fields[1], controllerServer) == 0 {
n.Log.Tracef("server entry %s already in /etc/hosts, skip", controllerServer)
return true, ipaddrCached, maxTTLSec
}
serverEntry := fmt.Sprintf("%s %s\n", lookupIPaddr, controllerServer)
newhosts = append(etchosts, []byte(serverEntry)...)
gotipentry = true
// a rare event for dns address change, log it
n.Log.Noticef("dnsServer %s, ttl %d, entry add to /etc/hosts: %s", nameServer, ttlSec, serverEntry)
break
}
if gotipentry {
break
}
}
if gotipentry {
break
}
}

if ipaddrCached == lookupIPaddr {
return ipaddrCached, minTTLSec
}
if !gotipentry { // put original /etc/hosts file back
newhosts = append(newhosts, etchosts...)
}

ipaddrCached = ""
err := os.WriteFile(tmpHostFileName, newhosts, 0644)
if err == nil {
if err := os.Rename(tmpHostFileName, etcHostFileName); err != nil {
n.Log.Errorf("can not rename /etc/hosts file %v", err)
} else {
if gotipentry {
ipaddrCached = lookupIPaddr
}
n.Log.Tracef("append controller IP %s to /etc/hosts", lookupIPaddr)
}
} else {
n.Log.Errorf("can not write /tmp/etchosts file %v", err)
}

return ipaddrCached, ttlSec
return false, "", 0
}

func getTTL(ttl time.Duration) int {
Expand Down
79 changes: 79 additions & 0 deletions pkg/pillar/cmd/nim/controllerdns_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package nim

import (
"fmt"
"io"
"os"
"strings"
"testing"

"github.com/lf-edge/eve/pkg/pillar/base"
"github.com/lf-edge/eve/pkg/pillar/devicenetwork"
"github.com/lf-edge/eve/pkg/pillar/dpcmanager"
"github.com/sirupsen/logrus"
)

func createTestNim() *nim {
var n nim

dpcManager := dpcmanager.DpcManager{}
n.dpcManager = &dpcManager
logger := logrus.StandardLogger()
log := base.NewSourceLogObject(logger, "zedagent", 1234)
n.Logger = logger
n.Log = log

return &n
}

func TestControllerDNSCacheIndexOutOfRange(t *testing.T) {
// Regression test for bug introduced by switching to miekg/dns
n := createTestNim()

n.controllerDNSCache([]byte(""), []byte("1.1"), "")
}

func TestWriteHostsFile(t *testing.T) {
n := createTestNim()

dnsResponses := []devicenetwork.DNSResponse{
{
IP: []byte{1, 1, 1, 1},
},
{
IP: []byte{1, 0, 0, 1},
},
}

dnsName := "one.one.one.one"

f, err := os.CreateTemp("", "writeHostsFile.*.etchosts")
if err != nil {
panic(err)
}
defer os.Remove(f.Name())
f.Close()

n.writeHostsFileToDestination(dnsResponses, []byte{}, []byte(dnsName), f.Name())

// reopen the file to be able to read what has been written by writeHostsFileToDestination; f.Seek(0, 0) unfortunately is not enough
f, err = os.Open(f.Name())
if err != nil {
panic(err)
}
content, err := io.ReadAll(f)
if err != nil {
panic(err)
}

for _, dnsResponse := range dnsResponses {
expectedContent := fmt.Sprintf("%s %s\n", dnsResponse.IP.String(), dnsName)
if !strings.Contains(string(content), expectedContent) {
t.Fatalf(
"writing to hosts file failed, expected: '%s', got: '%s'",
expectedContent,
content,
)
}
}
}
Loading