Skip to content

Commit

Permalink
feat(azure): add zone name filter for Azure Private DNS
Browse files Browse the repository at this point in the history
  • Loading branch information
khuedoan committed Mar 28, 2024
1 parent 61da7cc commit 425dea4
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 7 deletions.
2 changes: 1 addition & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ func main() {
case "azure-dns", "azure":
p, err = azure.NewAzureProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.DryRun)
case "azure-private-dns":
p, err = azure.NewAzurePrivateDNSProvider(cfg.AzureConfigFile, domainFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.DryRun)
p, err = azure.NewAzurePrivateDNSProvider(cfg.AzureConfigFile, domainFilter, zoneNameFilter, zoneIDFilter, cfg.AzureSubscriptionID, cfg.AzureResourceGroup, cfg.AzureUserAssignedIdentityClientID, cfg.DryRun)
case "bluecat":
p, err = bluecat.NewBluecatProvider(cfg.BluecatConfigFile, cfg.BluecatDNSConfiguration, cfg.BluecatDNSServerName, cfg.BluecatDNSDeployType, cfg.BluecatDNSView, cfg.BluecatGatewayHost, cfg.BluecatRootZone, cfg.TXTPrefix, cfg.TXTSuffix, domainFilter, zoneIDFilter, cfg.DryRun, cfg.BluecatSkipTLSVerify)
case "vinyldns":
Expand Down
19 changes: 18 additions & 1 deletion provider/azure/azure_private_dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type PrivateRecordSetsClient interface {
type AzurePrivateDNSProvider struct {
provider.BaseProvider
domainFilter endpoint.DomainFilter
zoneNameFilter endpoint.DomainFilter
zoneIDFilter provider.ZoneIDFilter
dryRun bool
resourceGroup string
Expand All @@ -59,7 +60,7 @@ type AzurePrivateDNSProvider struct {
// NewAzurePrivateDNSProvider creates a new Azure Private DNS provider.
//
// Returns the provider or an error if a provider could not be created.
func NewAzurePrivateDNSProvider(configFile string, domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, dryRun bool) (*AzurePrivateDNSProvider, error) {
func NewAzurePrivateDNSProvider(configFile string, domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, dryRun bool) (*AzurePrivateDNSProvider, error) {
cfg, err := getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityClientID)
if err != nil {
return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err)
Expand All @@ -79,6 +80,7 @@ func NewAzurePrivateDNSProvider(configFile string, domainFilter endpoint.DomainF
}
return &AzurePrivateDNSProvider{
domainFilter: domainFilter,
zoneNameFilter: zoneNameFilter,
zoneIDFilter: zoneIDFilter,
dryRun: dryRun,
resourceGroup: cfg.ResourceGroup,
Expand Down Expand Up @@ -122,6 +124,10 @@ func (p *AzurePrivateDNSProvider) Records(ctx context.Context) (endpoints []*end
}
name = formatAzureDNSName(*recordSet.Name, *zone.Name)

if len(p.zoneNameFilter.Filters) > 0 && !p.domainFilter.Match(name) {
log.Debugf("Skipping return of record %s because it was filtered out by the specified --domain-filter", name)
continue
}
targets := extractAzurePrivateDNSTargets(recordSet)
if len(targets) == 0 {
log.Debugf("Failed to extract targets for '%s' with type '%s'.", name, recordType)
Expand Down Expand Up @@ -183,6 +189,9 @@ func (p *AzurePrivateDNSProvider) zones(ctx context.Context) ([]privatedns.Priva

if zone.Name != nil && p.domainFilter.Match(*zone.Name) && p.zoneIDFilter.Match(*zone.ID) {
zones = append(zones, *zone)
} else if zone.Name != nil && len(p.zoneNameFilter.Filters) > 0 && p.zoneNameFilter.Match(*zone.Name) {
// Handle zoneNameFilter
zones = append(zones, *zone)
}
}
}
Expand Down Expand Up @@ -236,6 +245,10 @@ func (p *AzurePrivateDNSProvider) deleteRecords(ctx context.Context, deleted azu
for zone, endpoints := range deleted {
for _, ep := range endpoints {
name := p.recordSetNameForZone(zone, ep)
if !p.domainFilter.Match(ep.DNSName) {
log.Debugf("Skipping deletion of record %s because it was filtered out by the specified --domain-filter", ep.DNSName)
continue
}
if p.dryRun {
log.Infof("Would delete %s record named '%s' for Azure Private DNS zone '%s'.", ep.RecordType, name, zone)
} else {
Expand All @@ -259,6 +272,10 @@ func (p *AzurePrivateDNSProvider) updateRecords(ctx context.Context, updated azu
for zone, endpoints := range updated {
for _, ep := range endpoints {
name := p.recordSetNameForZone(zone, ep)
if !p.domainFilter.Match(ep.DNSName) {
log.Debugf("Skipping update of record %s because it was filtered out by the specified --domain-filter", ep.DNSName)
continue
}
if p.dryRun {
log.Infof(
"Would update %s record named '%s' to '%s' for Azure Private DNS zone '%s'.",
Expand Down
12 changes: 7 additions & 5 deletions provider/azure/azure_privatedns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,15 +224,16 @@ func createPrivateMockRecordSetMultiWithTTL(name, recordType string, ttl int64,
}

// newMockedAzurePrivateDNSProvider creates an AzureProvider comprising the mocked clients for zones and recordsets
func newMockedAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, zones []*privatedns.PrivateZone, recordSets []*privatedns.RecordSet) (*AzurePrivateDNSProvider, error) {
func newMockedAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, zones []*privatedns.PrivateZone, recordSets []*privatedns.RecordSet) (*AzurePrivateDNSProvider, error) {
zonesClient := newMockPrivateZonesClient(zones)
recordSetsClient := newMockPrivateRecordSectsClient(recordSets)
return newAzurePrivateDNSProvider(domainFilter, zoneIDFilter, dryRun, resourceGroup, &zonesClient, &recordSetsClient), nil
return newAzurePrivateDNSProvider(domainFilter, zoneNameFilter, zoneIDFilter, dryRun, resourceGroup, &zonesClient, &recordSetsClient), nil
}

func newAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, privateZonesClient PrivateZonesClient, privateRecordsClient PrivateRecordSetsClient) *AzurePrivateDNSProvider {
func newAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, resourceGroup string, privateZonesClient PrivateZonesClient, privateRecordsClient PrivateRecordSetsClient) *AzurePrivateDNSProvider {
return &AzurePrivateDNSProvider{
domainFilter: domainFilter,
zoneNameFilter: zoneNameFilter,
zoneIDFilter: zoneIDFilter,
dryRun: dryRun,
resourceGroup: resourceGroup,
Expand All @@ -242,7 +243,7 @@ func newAzurePrivateDNSProvider(domainFilter endpoint.DomainFilter, zoneIDFilter
}

func TestAzurePrivateDNSRecord(t *testing.T) {
provider, err := newMockedAzurePrivateDNSProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s",
provider, err := newMockedAzurePrivateDNSProvider(endpoint.NewDomainFilter([]string{"example.com"}), endpoint.NewDomainFilter([]string{}), provider.NewZoneIDFilter([]string{""}), true, "k8s",
[]*privatedns.PrivateZone{
createMockPrivateZone("example.com", "/privateDnsZones/example.com"),
},
Expand Down Expand Up @@ -281,7 +282,7 @@ func TestAzurePrivateDNSRecord(t *testing.T) {
}

func TestAzurePrivateDNSMultiRecord(t *testing.T) {
provider, err := newMockedAzurePrivateDNSProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, "k8s",
provider, err := newMockedAzurePrivateDNSProvider(endpoint.NewDomainFilter([]string{"example.com"}), endpoint.NewDomainFilter([]string{}), provider.NewZoneIDFilter([]string{""}), true, "k8s",
[]*privatedns.PrivateZone{
createMockPrivateZone("example.com", "/privateDnsZones/example.com"),
},
Expand Down Expand Up @@ -369,6 +370,7 @@ func testAzurePrivateDNSApplyChangesInternal(t *testing.T, dryRun bool, client P
zonesClient := newMockPrivateZonesClient(zones)

provider := newAzurePrivateDNSProvider(
endpoint.NewDomainFilter([]string{""}),
endpoint.NewDomainFilter([]string{""}),
provider.NewZoneIDFilter([]string{""}),
dryRun,
Expand Down

0 comments on commit 425dea4

Please sign in to comment.