Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update rfc2136 provider to split out changes per zone #4107

Merged
76 changes: 48 additions & 28 deletions provider/rfc2136/rfc2136.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,12 @@ func (r rfc2136Provider) ApplyChanges(ctx context.Context, changes *plan.Changes
for c, chunk := range chunkBy(changes.Create, r.batchChangeSize) {
log.Debugf("Processing batch %d of create changes", c)

m := new(dns.Msg)
m := make(map[string]*dns.Msg)
m["."] = new(dns.Msg) // Add the root zone
for _, z := range r.zoneNames {
z = dns.Fqdn(z)
m[z] = new(dns.Msg)
}
for _, ep := range chunk {
if !r.domainFilter.Match(ep.DNSName) {
log.Debugf("Skipping record %s because it was filtered out by the specified --domain-filter", ep.DNSName)
Expand All @@ -265,26 +270,33 @@ func (r rfc2136Provider) ApplyChanges(ctx context.Context, changes *plan.Changes

zone := findMsgZone(ep, r.zoneNames)
r.krb5Realm = strings.ToUpper(zone)
m.SetUpdate(zone)
m[zone].SetUpdate(zone)

r.AddRecord(m, ep)
r.AddRecord(m[zone], ep)
}

// only send if there are records available
if len(m.Ns) > 0 {
err := r.actions.SendMessage(m)
if err != nil {
log.Errorf("RFC2136 update failed: %v", err)
errors = append(errors, err)
continue
for _, z := range m {
if len(z.Ns) > 0 {
err := r.actions.SendMessage(z)
if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
err := r.actions.SendMessage(z)
if err != nil {
if err := r.actions.SendMessage(z); err != nil {

log.Errorf("RFC2136 update failed: %v", err)
Copy link
Contributor

@mloiseleur mloiseleur Jan 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
log.Errorf("RFC2136 update failed: %v", err)
log.Errorf("RFC2136 create record failed: %v", err)

errors = append(errors, err)
continue
}
}
}
}

for c, chunk := range chunkBy(changes.UpdateNew, r.batchChangeSize) {
log.Debugf("Processing batch %d of update changes", c)

m := new(dns.Msg)
m := make(map[string]*dns.Msg)
m["."] = new(dns.Msg) // Add the root zone
for _, z := range r.zoneNames {
z = dns.Fqdn(z)
m[z] = new(dns.Msg)
}

for i, ep := range chunk {
if !r.domainFilter.Match(ep.DNSName) {
Expand All @@ -294,27 +306,33 @@ func (r rfc2136Provider) ApplyChanges(ctx context.Context, changes *plan.Changes

zone := findMsgZone(ep, r.zoneNames)
r.krb5Realm = strings.ToUpper(zone)
m.SetUpdate(zone)
m[zone].SetUpdate(zone)

r.UpdateRecord(m, changes.UpdateOld[i], ep)
r.UpdateRecord(m[zone], changes.UpdateOld[i], ep)
}

// only send if there are records available
if len(m.Ns) > 0 {
err := r.actions.SendMessage(m)
if err != nil {
log.Errorf("RFC2136 update failed: %v", err)
errors = append(errors, err)
continue
for _, z := range m {
if len(z.Ns) > 0 {
err := r.actions.SendMessage(z)
if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
err := r.actions.SendMessage(z)
if err != nil {
if err := r.actions.SendMessage(z); err != nil {

log.Errorf("RFC2136 update failed: %v", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
log.Errorf("RFC2136 update failed: %v", err)
log.Errorf("RFC2136 update record failed: %v", err)

errors = append(errors, err)
continue
}
}
}
}

for c, chunk := range chunkBy(changes.Delete, r.batchChangeSize) {
log.Debugf("Processing batch %d of delete changes", c)

m := new(dns.Msg)

m := make(map[string]*dns.Msg)
m["."] = new(dns.Msg) // Add the root zone
for _, z := range r.zoneNames {
z = dns.Fqdn(z)
m[z] = new(dns.Msg)
}
for _, ep := range chunk {
if !r.domainFilter.Match(ep.DNSName) {
log.Debugf("Skipping record %s because it was filtered out by the specified --domain-filter", ep.DNSName)
Expand All @@ -323,18 +341,20 @@ func (r rfc2136Provider) ApplyChanges(ctx context.Context, changes *plan.Changes

zone := findMsgZone(ep, r.zoneNames)
r.krb5Realm = strings.ToUpper(zone)
m.SetUpdate(zone)
m[zone].SetUpdate(zone)

r.RemoveRecord(m, ep)
r.RemoveRecord(m[zone], ep)
}

// only send if there are records available
if len(m.Ns) > 0 {
err := r.actions.SendMessage(m)
if err != nil {
log.Errorf("RFC2136 update failed: %v", err)
errors = append(errors, err)
continue
for _, z := range m {
if len(z.Ns) > 0 {
err := r.actions.SendMessage(z)
if err != nil {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
err := r.actions.SendMessage(z)
if err != nil {
if err := r.actions.SendMessage(z); err != nil {

log.Errorf("RFC2136 update failed: %v", err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
log.Errorf("RFC2136 update failed: %v", err)
log.Errorf("RFC2136 delete record failed: %v", err)

errors = append(errors, err)
continue
}
}
}
}
Expand Down
206 changes: 205 additions & 1 deletion provider/rfc2136/rfc2136_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package rfc2136
import (
"context"
"fmt"
"regexp"
"sort"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -46,8 +48,23 @@ func newStub() *rfc2136Stub {
}
}

func getSortedChanges(msgs []*dns.Msg) []string {
r := []string{}
for _, d := range msgs {
// only care about section after the ZONE SECTION: as the id: needs stripped out in order to sort and grantee the order when sorting
r = append(r, strings.Split(d.String(), "ZONE SECTION:")[1])
}
sort.Strings(r)
return r
}

func (r *rfc2136Stub) SendMessage(msg *dns.Msg) error {
log.Info(msg.String())
zone := extractZoneFromMessage(msg.String())
// Make sure the zone starts with . to make sure HasSuffix does not match forbar.com for zone bar.com
if !strings.HasPrefix(zone, ".") {
zone = "." + zone
}
log.Infof("zone=%s", zone)
lines := extractUpdateSectionFromMessage(msg)
for _, line := range lines {
// break at first empty line
Expand All @@ -57,6 +74,12 @@ func (r *rfc2136Stub) SendMessage(msg *dns.Msg) error {

line = strings.Replace(line, "\t", " ", -1)
log.Info(line)
record := strings.Split(line, " ")[0]
if !strings.HasSuffix(record, zone) {
err := fmt.Errorf("Message contains updates outside of it's zone. zone=%v record=%v", zone, record)
log.Error(err)
return err
}

if strings.Contains(line, " NONE ") {
r.updateMsgs = append(r.updateMsgs, msg)
Expand Down Expand Up @@ -98,12 +121,28 @@ func createRfc2136StubProvider(stub *rfc2136Stub) (provider.Provider, error) {
return NewRfc2136Provider("", 0, nil, false, "key", "secret", "hmac-sha512", true, endpoint.DomainFilter{}, false, 300*time.Second, false, "", "", "", 50, stub)
}

func createRfc2136StubProviderWithZones(stub *rfc2136Stub) (provider.Provider, error) {
zones := []string{"foo.com", "foobar.com"}
return NewRfc2136Provider("", 0, zones, false, "key", "secret", "hmac-sha512", true, endpoint.DomainFilter{}, false, 300*time.Second, false, "", "", "", 50, stub)
}

func createRfc2136StubProviderWithZonesFilters(stub *rfc2136Stub) (provider.Provider, error) {
zones := []string{"foo.com", "foobar.com"}
return NewRfc2136Provider("", 0, zones, false, "key", "secret", "hmac-sha512", true, endpoint.DomainFilter{Filters: zones}, false, 300*time.Second, false, "", "", "", 50, stub)
}

func extractUpdateSectionFromMessage(msg fmt.Stringer) []string {
const searchPattern = "UPDATE SECTION:"
updateSectionOffset := strings.Index(msg.String(), searchPattern)
return strings.Split(strings.TrimSpace(msg.String()[updateSectionOffset+len(searchPattern):]), "\n")
}

func extractZoneFromMessage(msg string) string {
re := regexp.MustCompile(`ZONE SECTION:\n;(?P<ZONE>[\.,\-,\w,\d]+)\t`)
matches := re.FindStringSubmatch(msg)
return matches[re.SubexpIndex("ZONE")]
}

// TestRfc2136GetRecordsMultipleTargets simulates a single record with multiple targets.
func TestRfc2136GetRecordsMultipleTargets(t *testing.T) {
stub := newStub()
Expand Down Expand Up @@ -154,6 +193,32 @@ func TestRfc2136GetRecords(t *testing.T) {
assert.True(t, contains(recs, "v2.foo.com"))
}

// Make sure the test version of SendMessage raises an error
// if a zone update ever contains records outside of it's zone
// as the TestRfc2136ApplyChanges tests all assume this
func TestRfc2136SendMessage(t *testing.T) {
stub := newStub()

m := new(dns.Msg)
m.SetUpdate("foo.com.")
rr, err := dns.NewRR(fmt.Sprintf("%s %d %s %s", "v1.foo.com.", 0, "A", "1.2.3.4"))
m.Insert([]dns.RR{rr})

err = stub.SendMessage(m)
assert.NoError(t, err)

rr, err = dns.NewRR(fmt.Sprintf("%s %d %s %s", "v1.bar.com.", 0, "A", "1.2.3.4"))
m.Insert([]dns.RR{rr})

err = stub.SendMessage(m)
assert.Error(t, err)

m.SetUpdate(".")
err = stub.SendMessage(m)
assert.NoError(t, err)
}

// These tests are use the . root zone with no filters
func TestRfc2136ApplyChanges(t *testing.T) {
stub := newStub()
provider, err := createRfc2136StubProvider(stub)
Expand Down Expand Up @@ -210,6 +275,145 @@ func TestRfc2136ApplyChanges(t *testing.T) {
assert.True(t, strings.Contains(stub.updateMsgs[1].String(), "v2.foobar.com"))
}

// These tests all use the foo.com and foobar.com zones with no filters
// createMsgs and updateMsgs need sorted when are are used
func TestRfc2136ApplyChangesWithZones(t *testing.T) {
stub := newStub()
provider, err := createRfc2136StubProviderWithZones(stub)
assert.NoError(t, err)

p := &plan.Changes{
Create: []*endpoint.Endpoint{
{
DNSName: "v1.foo.com",
RecordType: "A",
Targets: []string{"1.2.3.4"},
RecordTTL: endpoint.TTL(400),
},
{
DNSName: "v1.foobar.com",
RecordType: "TXT",
Targets: []string{"boom"},
},
{
DNSName: "ns.foobar.com",
RecordType: "NS",
Targets: []string{"boom"},
},
},
Delete: []*endpoint.Endpoint{
{
DNSName: "v2.foo.com",
RecordType: "A",
Targets: []string{"1.2.3.4"},
},
{
DNSName: "v2.foobar.com",
RecordType: "TXT",
Targets: []string{"boom2"},
},
},
}

err = provider.ApplyChanges(context.Background(), p)
assert.NoError(t, err)

assert.Equal(t, 3, len(stub.createMsgs))
createMsgs := getSortedChanges(stub.createMsgs)
assert.Equal(t, 3, len(createMsgs))

assert.True(t, strings.Contains(createMsgs[0], "v1.foo.com"))
assert.True(t, strings.Contains(createMsgs[0], "1.2.3.4"))

assert.True(t, strings.Contains(createMsgs[1], "v1.foobar.com"))
assert.True(t, strings.Contains(createMsgs[1], "boom"))

assert.True(t, strings.Contains(createMsgs[2], "ns.foobar.com"))
assert.True(t, strings.Contains(createMsgs[2], "boom"))

assert.Equal(t, 2, len(stub.updateMsgs))
updateMsgs := getSortedChanges(stub.updateMsgs)
assert.Equal(t, 2, len(updateMsgs))

assert.True(t, strings.Contains(updateMsgs[0], "v2.foo.com"))
assert.True(t, strings.Contains(updateMsgs[1], "v2.foobar.com"))
}

// These tests use the foo.com and foobar.com zones and with filters set to both zones
// createMsgs and updateMsgs need sorted when are are used
func TestRfc2136ApplyChangesWithZonesFilters(t *testing.T) {
stub := newStub()
provider, err := createRfc2136StubProviderWithZonesFilters(stub)
assert.NoError(t, err)

p := &plan.Changes{
Create: []*endpoint.Endpoint{
{
DNSName: "v1.foo.com",
RecordType: "A",
Targets: []string{"1.2.3.4"},
RecordTTL: endpoint.TTL(400),
},
{
DNSName: "v1.foobar.com",
RecordType: "TXT",
Targets: []string{"boom"},
},
{
DNSName: "ns.foobar.com",
RecordType: "NS",
Targets: []string{"boom"},
},
{
DNSName: "filtered-out.foo.bar",
RecordType: "A",
Targets: []string{"1.2.3.4"},
RecordTTL: endpoint.TTL(400),
},
},
Delete: []*endpoint.Endpoint{
{
DNSName: "v2.foo.com",
RecordType: "A",
Targets: []string{"1.2.3.4"},
},
{
DNSName: "v2.foobar.com",
RecordType: "TXT",
Targets: []string{"boom2"},
},
},
}

err = provider.ApplyChanges(context.Background(), p)
assert.NoError(t, err)

assert.Equal(t, 3, len(stub.createMsgs))
createMsgs := getSortedChanges(stub.createMsgs)
assert.Equal(t, 3, len(createMsgs))

assert.True(t, strings.Contains(createMsgs[0], "v1.foo.com"))
assert.True(t, strings.Contains(createMsgs[0], "1.2.3.4"))

assert.True(t, strings.Contains(createMsgs[1], "v1.foobar.com"))
assert.True(t, strings.Contains(createMsgs[1], "boom"))

assert.True(t, strings.Contains(createMsgs[2], "ns.foobar.com"))
assert.True(t, strings.Contains(createMsgs[2], "boom"))

for _, s := range createMsgs {
assert.False(t, strings.Contains(s, "filtered-out.foo.bar"))
}

assert.Equal(t, 2, len(stub.updateMsgs))
updateMsgs := getSortedChanges(stub.updateMsgs)
assert.Equal(t, 2, len(updateMsgs))

assert.True(t, strings.Contains(updateMsgs[0], "v2.foo.com"))
assert.True(t, strings.Contains(updateMsgs[1], "v2.foobar.com"))

}

func TestRfc2136ApplyChangesWithDifferentTTLs(t *testing.T) {
stub := newStub()

Expand Down
Loading