diff --git a/provider/coredns/coredns_test.go b/provider/coredns/coredns_test.go index 84f124ac50..8c9aa8f9d6 100644 --- a/provider/coredns/coredns_test.go +++ b/provider/coredns/coredns_test.go @@ -23,27 +23,30 @@ import ( "sigs.k8s.io/external-dns/endpoint" "sigs.k8s.io/external-dns/plan" + + "github.com/stretchr/testify/require" ) const defaultCoreDNSPrefix = "/skydns/" type fakeETCDClient struct { - services map[string]*Service + services map[string]Service } func (c fakeETCDClient) GetServices(prefix string) ([]*Service, error) { var result []*Service for key, value := range c.services { if strings.HasPrefix(key, prefix) { - value.Key = key - result = append(result, value) + valueCopy := value + valueCopy.Key = key + result = append(result, &valueCopy) } } return result, nil } func (c fakeETCDClient) SaveService(service *Service) error { - c.services[service.Key] = service + c.services[service.Key] = *service return nil } @@ -58,7 +61,7 @@ func TestAServiceTranslation(t *testing.T) { expectedRecordType := endpoint.RecordTypeA client := fakeETCDClient{ - map[string]*Service{ + map[string]Service{ "/skydns/com/example": {Host: expectedTarget}, }, } @@ -67,9 +70,7 @@ func TestAServiceTranslation(t *testing.T) { coreDNSPrefix: defaultCoreDNSPrefix, } endpoints, err := provider.Records(context.Background()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if len(endpoints) != 1 { t.Fatalf("got unexpected number of endpoints: %d", len(endpoints)) } @@ -90,7 +91,7 @@ func TestCNAMEServiceTranslation(t *testing.T) { expectedRecordType := endpoint.RecordTypeCNAME client := fakeETCDClient{ - map[string]*Service{ + map[string]Service{ "/skydns/com/example": {Host: expectedTarget}, }, } @@ -99,9 +100,7 @@ func TestCNAMEServiceTranslation(t *testing.T) { coreDNSPrefix: defaultCoreDNSPrefix, } endpoints, err := provider.Records(context.Background()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if len(endpoints) != 1 { t.Fatalf("got unexpected number of endpoints: %d", len(endpoints)) } @@ -122,7 +121,7 @@ func TestTXTServiceTranslation(t *testing.T) { expectedRecordType := endpoint.RecordTypeTXT client := fakeETCDClient{ - map[string]*Service{ + map[string]Service{ "/skydns/com/example": {Text: expectedTarget}, }, } @@ -131,9 +130,7 @@ func TestTXTServiceTranslation(t *testing.T) { coreDNSPrefix: defaultCoreDNSPrefix, } endpoints, err := provider.Records(context.Background()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if len(endpoints) != 1 { t.Fatalf("got unexpected number of endpoints: %d", len(endpoints)) } @@ -156,7 +153,7 @@ func TestAWithTXTServiceTranslation(t *testing.T) { expectedDNSName := "example.com" client := fakeETCDClient{ - map[string]*Service{ + map[string]Service{ "/skydns/com/example": {Host: "1.2.3.4", Text: "string"}, }, } @@ -165,9 +162,7 @@ func TestAWithTXTServiceTranslation(t *testing.T) { coreDNSPrefix: defaultCoreDNSPrefix, } endpoints, err := provider.Records(context.Background()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if len(endpoints) != len(expectedTargets) { t.Fatalf("got unexpected number of endpoints: %d", len(endpoints)) } @@ -198,7 +193,7 @@ func TestCNAMEWithTXTServiceTranslation(t *testing.T) { expectedDNSName := "example.com" client := fakeETCDClient{ - map[string]*Service{ + map[string]Service{ "/skydns/com/example": {Host: "example.net", Text: "string"}, }, } @@ -207,9 +202,7 @@ func TestCNAMEWithTXTServiceTranslation(t *testing.T) { coreDNSPrefix: defaultCoreDNSPrefix, } endpoints, err := provider.Records(context.Background()) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) if len(endpoints) != len(expectedTargets) { t.Fatalf("got unexpected number of endpoints: %d", len(endpoints)) } @@ -234,7 +227,7 @@ func TestCNAMEWithTXTServiceTranslation(t *testing.T) { func TestCoreDNSApplyChanges(t *testing.T) { client := fakeETCDClient{ - map[string]*Service{}, + map[string]Service{}, } coredns := coreDNSProvider{ client: client, @@ -248,11 +241,12 @@ func TestCoreDNSApplyChanges(t *testing.T) { endpoint.NewEndpoint("domain2.local", endpoint.RecordTypeCNAME, "site.local"), }, } - coredns.ApplyChanges(context.Background(), changes1) + err := coredns.ApplyChanges(context.Background(), changes1) + require.NoError(t, err) - expectedServices1 := map[string]*Service{ - "/skydns/local/domain1": {Host: "5.5.5.5", Text: "string1"}, - "/skydns/local/domain2": {Host: "site.local"}, + expectedServices1 := map[string][]*Service{ + "/skydns/local/domain1": {{Host: "5.5.5.5", Text: "string1"}}, + "/skydns/local/domain2": {{Host: "site.local"}}, } validateServices(client.services, expectedServices1, t, 1) @@ -270,12 +264,13 @@ func TestCoreDNSApplyChanges(t *testing.T) { changes2.UpdateOld = append(changes2.UpdateOld, ep) } } - applyServiceChanges(coredns, changes2) + err = applyServiceChanges(coredns, changes2) + require.NoError(t, err) - expectedServices2 := map[string]*Service{ - "/skydns/local/domain1": {Host: "6.6.6.6", Text: "string1"}, - "/skydns/local/domain2": {Host: "site.local"}, - "/skydns/local/domain3": {Host: "7.7.7.7"}, + expectedServices2 := map[string][]*Service{ + "/skydns/local/domain1": {{Host: "6.6.6.6", Text: "string1"}}, + "/skydns/local/domain2": {{Host: "site.local"}}, + "/skydns/local/domain3": {{Host: "7.7.7.7"}}, } validateServices(client.services, expectedServices2, t, 2) @@ -287,10 +282,11 @@ func TestCoreDNSApplyChanges(t *testing.T) { }, } - applyServiceChanges(coredns, changes3) + err = applyServiceChanges(coredns, changes3) + require.NoError(t, err) - expectedServices3 := map[string]*Service{ - "/skydns/local/domain2": {Host: "site.local"}, + expectedServices3 := map[string][]*Service{ + "/skydns/local/domain2": {{Host: "site.local"}}, } validateServices(client.services, expectedServices3, t, 3) @@ -302,18 +298,17 @@ func TestCoreDNSApplyChanges(t *testing.T) { endpoint.NewEndpoint("domain1.local", endpoint.RecordTypeA, "7.7.7.7"), }, } - coredns.ApplyChanges(context.Background(), changes4) + err = coredns.ApplyChanges(context.Background(), changes4) + require.NoError(t, err) - expectedServices4 := map[string]*Service{ - "/skydns/local/domain2": {Host: "site.local"}, - "/skydns/local/domain1/1": {Host: "5.5.5.5"}, - "/skydns/local/domain1/2": {Host: "6.6.6.6"}, - "/skydns/local/domain1": {Host: "7.7.7.7"}, + expectedServices4 := map[string][]*Service{ + "/skydns/local/domain2": {{Host: "site.local"}}, + "/skydns/local/domain1": {{Host: "5.5.5.5"}, {Host: "6.6.6.6"}, {Host: "7.7.7.7"}}, } - validateServices(client.services, expectedServices4, t, 1) + validateServices(client.services, expectedServices4, t, 4) } -func applyServiceChanges(provider coreDNSProvider, changes *plan.Changes) { +func applyServiceChanges(provider coreDNSProvider, changes *plan.Changes) error { ctx := context.Background() records, _ := provider.Records(ctx) for _, col := range [][]*endpoint.Endpoint{changes.Create, changes.UpdateNew, changes.Delete} { @@ -325,29 +320,38 @@ func applyServiceChanges(provider coreDNSProvider, changes *plan.Changes) { } } } - provider.ApplyChanges(ctx, changes) + return provider.ApplyChanges(ctx, changes) } -func validateServices(services, expectedServices map[string]*Service, t *testing.T, step int) { +func validateServices(services map[string]Service, expectedServices map[string][]*Service, t *testing.T, step int) { t.Helper() - if len(services) != len(expectedServices) { - t.Errorf("wrong number of records on step %d: %d != %d", step, len(services), len(expectedServices)) - } for key, value := range services { keyParts := strings.Split(key, "/") expectedKey := strings.Join(keyParts[:len(keyParts)-value.TargetStrip], "/") - expectedService := expectedServices[expectedKey] - if expectedService == nil { + expectedServiceEntries := expectedServices[expectedKey] + if expectedServiceEntries == nil { t.Errorf("unexpected service %s", key) continue } - delete(expectedServices, key) - if value.Host != expectedService.Host { - t.Errorf("wrong host for service %s: %s != %s on step %d", key, value.Host, expectedService.Host, step) + found := false + for i, expectedServiceEntry := range expectedServiceEntries { + if value.Host == expectedServiceEntry.Host && value.Text == expectedServiceEntry.Text { + expectedServiceEntries = append(expectedServiceEntries[:i], expectedServiceEntries[i+1:]...) + found = true + break + } } - if value.Text != expectedService.Text { - t.Errorf("wrong text for service %s: %s != %s on step %d", key, value.Text, expectedService.Text, step) + if !found { + t.Errorf("unexpected service %s: %s on step %d", key, value.Host, step) } + if len(expectedServiceEntries) == 0 { + delete(expectedServices, expectedKey) + } else { + expectedServices[expectedKey] = expectedServiceEntries + } + } + if len(expectedServices) != 0 { + t.Errorf("unmatched expected services: %+v on step %d", expectedServices, step) } }