Skip to content

Commit

Permalink
Simplify implementation of inmemory provider
Browse files Browse the repository at this point in the history
  • Loading branch information
johngmyers committed Jun 5, 2023
1 parent 6cc5884 commit 10fde9c
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 372 deletions.
122 changes: 34 additions & 88 deletions provider/inmemory/inmemory.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"strings"

log "github.com/sirupsen/logrus"
"k8s.io/apimachinery/pkg/util/sets"

"sigs.k8s.io/external-dns/endpoint"
"sigs.k8s.io/external-dns/plan"
Expand Down Expand Up @@ -132,11 +133,7 @@ func (im *InMemoryProvider) Records(ctx context.Context) ([]*endpoint.Endpoint,
return nil, err
}

for _, record := range records {
ep := endpoint.NewEndpoint(record.Name, record.Type, record.Target).WithSetIdentifier(record.SetIdentifier)
ep.Labels = record.Labels
endpoints = append(endpoints, ep)
}
endpoints = append(endpoints, copyEndpoints(records)...)
}

return endpoints, nil
Expand Down Expand Up @@ -187,11 +184,11 @@ func (im *InMemoryProvider) ApplyChanges(ctx context.Context, changes *plan.Chan
}

for zoneID := range perZoneChanges {
change := &inMemoryChange{
Create: convertToInMemoryRecord(perZoneChanges[zoneID].Create),
UpdateNew: convertToInMemoryRecord(perZoneChanges[zoneID].UpdateNew),
UpdateOld: convertToInMemoryRecord(perZoneChanges[zoneID].UpdateOld),
Delete: convertToInMemoryRecord(perZoneChanges[zoneID].Delete),
change := &plan.Changes{
Create: perZoneChanges[zoneID].Create,
UpdateNew: perZoneChanges[zoneID].UpdateNew,
UpdateOld: perZoneChanges[zoneID].UpdateOld,
Delete: perZoneChanges[zoneID].Delete,
}
err := im.client.ApplyChanges(ctx, zoneID, change)
if err != nil {
Expand All @@ -202,16 +199,15 @@ func (im *InMemoryProvider) ApplyChanges(ctx context.Context, changes *plan.Chan
return nil
}

func convertToInMemoryRecord(endpoints []*endpoint.Endpoint) []*inMemoryRecord {
records := []*inMemoryRecord{}
func copyEndpoints(endpoints []*endpoint.Endpoint) []*endpoint.Endpoint {
records := make([]*endpoint.Endpoint, 0, len(endpoints))
for _, ep := range endpoints {
records = append(records, &inMemoryRecord{
Type: ep.RecordType,
Name: ep.DNSName,
Target: ep.Targets[0],
SetIdentifier: ep.SetIdentifier,
Labels: ep.Labels,
})
newEp := endpoint.NewEndpointWithTTL(ep.DNSName, ep.RecordType, ep.RecordTTL, ep.Targets...).WithSetIdentifier(ep.SetIdentifier)
newEp.Labels = endpoint.NewLabels()
for k, v := range ep.Labels {
newEp.Labels[k] = v
}
records = append(records, newEp)
}
return records
}
Expand Down Expand Up @@ -244,26 +240,7 @@ func (f *filter) EndpointZoneID(endpoint *endpoint.Endpoint, zones map[string]st
return matchZoneID
}

// inMemoryRecord - record stored in memory
// Type - type of record
// Name - DNS name assigned to the record
// Target - target of the record
type inMemoryRecord struct {
Type string
SetIdentifier string
Name string
Target string
Labels endpoint.Labels
}

type zone map[string][]*inMemoryRecord

type inMemoryChange struct {
Create []*inMemoryRecord
UpdateNew []*inMemoryRecord
UpdateOld []*inMemoryRecord
Delete []*inMemoryRecord
}
type zone map[endpoint.EndpointKey]*endpoint.Endpoint

type inMemoryClient struct {
zones map[string]zone
Expand All @@ -273,14 +250,14 @@ func newInMemoryClient() *inMemoryClient {
return &inMemoryClient{map[string]zone{}}
}

func (c *inMemoryClient) Records(zone string) ([]*inMemoryRecord, error) {
func (c *inMemoryClient) Records(zone string) ([]*endpoint.Endpoint, error) {
if _, ok := c.zones[zone]; !ok {
return nil, ErrZoneNotFound
}

records := []*inMemoryRecord{}
var records []*endpoint.Endpoint
for _, rec := range c.zones[zone] {
records = append(records, rec...)
records = append(records, rec)
}
return records, nil
}
Expand All @@ -297,87 +274,65 @@ func (c *inMemoryClient) CreateZone(zone string) error {
if _, ok := c.zones[zone]; ok {
return ErrZoneAlreadyExists
}
c.zones[zone] = map[string][]*inMemoryRecord{}
c.zones[zone] = map[endpoint.EndpointKey]*endpoint.Endpoint{}

return nil
}

func (c *inMemoryClient) ApplyChanges(ctx context.Context, zoneID string, changes *inMemoryChange) error {
func (c *inMemoryClient) ApplyChanges(ctx context.Context, zoneID string, changes *plan.Changes) error {
if err := c.validateChangeBatch(zoneID, changes); err != nil {
return err
}
for _, newEndpoint := range changes.Create {
if _, ok := c.zones[zoneID][newEndpoint.Name]; !ok {
c.zones[zoneID][newEndpoint.Name] = make([]*inMemoryRecord, 0)
}
c.zones[zoneID][newEndpoint.Name] = append(c.zones[zoneID][newEndpoint.Name], newEndpoint)
c.zones[zoneID][newEndpoint.Key()] = newEndpoint
}
for _, updateEndpoint := range changes.UpdateNew {
for _, rec := range c.zones[zoneID][updateEndpoint.Name] {
if rec.Type == updateEndpoint.Type {
rec.Target = updateEndpoint.Target
break
}
}
c.zones[zoneID][updateEndpoint.Key()] = updateEndpoint
}
for _, deleteEndpoint := range changes.Delete {
newSet := make([]*inMemoryRecord, 0)
for _, rec := range c.zones[zoneID][deleteEndpoint.Name] {
if rec.Type != deleteEndpoint.Type {
newSet = append(newSet, rec)
}
}
c.zones[zoneID][deleteEndpoint.Name] = newSet
delete(c.zones[zoneID], deleteEndpoint.Key())
}
return nil
}

func (c *inMemoryClient) updateMesh(mesh map[string]map[string]map[string]bool, record *inMemoryRecord) error {
if _, exists := mesh[record.Name]; exists {
if _, exists := mesh[record.Name][record.Type]; exists {
if mesh[record.Name][record.Type][record.SetIdentifier] {
return ErrDuplicateRecordFound
}
mesh[record.Name][record.Type][record.SetIdentifier] = true
return nil
}
mesh[record.Name][record.Type] = map[string]bool{record.SetIdentifier: true}
return nil
func (c *inMemoryClient) updateMesh(mesh sets.Set[endpoint.EndpointKey], record *endpoint.Endpoint) error {
if mesh.Has(record.Key()) {
return ErrDuplicateRecordFound
}
mesh[record.Name] = map[string]map[string]bool{record.Type: {record.SetIdentifier: true}}
mesh.Insert(record.Key())
return nil
}

// validateChangeBatch validates that the changes passed to InMemory DNS provider is valid
func (c *inMemoryClient) validateChangeBatch(zone string, changes *inMemoryChange) error {
func (c *inMemoryClient) validateChangeBatch(zone string, changes *plan.Changes) error {
curZone, ok := c.zones[zone]
if !ok {
return ErrZoneNotFound
}
mesh := map[string]map[string]map[string]bool{}
mesh := sets.New[endpoint.EndpointKey]()
for _, newEndpoint := range changes.Create {
if c.findByTypeAndSetIdentifier(newEndpoint.Type, newEndpoint.SetIdentifier, curZone[newEndpoint.Name]) != nil {
if _, exists := curZone[newEndpoint.Key()]; exists {
return ErrRecordAlreadyExists
}
if err := c.updateMesh(mesh, newEndpoint); err != nil {
return err
}
}
for _, updateEndpoint := range changes.UpdateNew {
if c.findByTypeAndSetIdentifier(updateEndpoint.Type, updateEndpoint.SetIdentifier, curZone[updateEndpoint.Name]) == nil {
if _, exists := curZone[updateEndpoint.Key()]; !exists {
return ErrRecordNotFound
}
if err := c.updateMesh(mesh, updateEndpoint); err != nil {
return err
}
}
for _, updateOldEndpoint := range changes.UpdateOld {
if rec := c.findByTypeAndSetIdentifier(updateOldEndpoint.Type, updateOldEndpoint.SetIdentifier, curZone[updateOldEndpoint.Name]); rec == nil || rec.Target != updateOldEndpoint.Target {
if rec, exists := curZone[updateOldEndpoint.Key()]; !exists || rec.Targets[0] != updateOldEndpoint.Targets[0] {
return ErrRecordNotFound
}
}
for _, deleteEndpoint := range changes.Delete {
if rec := c.findByTypeAndSetIdentifier(deleteEndpoint.Type, deleteEndpoint.SetIdentifier, curZone[deleteEndpoint.Name]); rec == nil || rec.Target != deleteEndpoint.Target {
if rec, exists := curZone[deleteEndpoint.Key()]; !exists || rec.Targets[0] != deleteEndpoint.Targets[0] {
return ErrRecordNotFound
}
if err := c.updateMesh(mesh, deleteEndpoint); err != nil {
Expand All @@ -386,12 +341,3 @@ func (c *inMemoryClient) validateChangeBatch(zone string, changes *inMemoryChang
}
return nil
}

func (c *inMemoryClient) findByTypeAndSetIdentifier(recordType, setIdentifier string, records []*inMemoryRecord) *inMemoryRecord {
for _, record := range records {
if record.Type == recordType && record.SetIdentifier == setIdentifier {
return record
}
}
return nil
}
Loading

0 comments on commit 10fde9c

Please sign in to comment.