Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Only update PiHole entries when they have actually changed #3297

Merged
merged 4 commits into from
Feb 5, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions provider/pihole/pihole.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,25 +74,40 @@ func (p *PiholeProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, err

// ApplyChanges implements Provider, syncing desired state with the Pi-hole server Local DNS.
func (p *PiholeProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
// Handle deletions first - there are no endpoints for updating in place.
// Handle pure deletes first.
for _, ep := range changes.Delete {
if err := p.api.deleteRecord(ctx, ep); err != nil {
return err
}
}

// Handle updated state - there are no endpoints for updating in place.
updateNew := make(map[string]*endpoint.Endpoint)
for _, ep := range changes.UpdateNew {
updateNew[ep.DNSName] = ep
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if changes.UpdateNew has multiple endpoints with the same DNSName? Shouldn't you put the RecordType in the map key?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! PiHole doesn't allow two records of the same RecordType with the same DNS name, but it does allow two records with the same DNS name if the record types are different. Fixed

}

for _, ep := range changes.UpdateOld {
// Check if this existing entry has an exact match for an updated entry, and skip it if so.
sfleener marked this conversation as resolved.
Show resolved Hide resolved
if newRecord := updateNew[ep.DNSName]; newRecord != nil {
// PiHole only has a single target and a record type, no need to compare other fields.
sfleener marked this conversation as resolved.
Show resolved Hide resolved
if newRecord.Targets.String() == ep.Targets.String() && newRecord.RecordType == ep.RecordType {
sfleener marked this conversation as resolved.
Show resolved Hide resolved
delete(updateNew, ep.DNSName)
continue
}
}
if err := p.api.deleteRecord(ctx, ep); err != nil {
return err
}
}

// Handle desired state
// Handle pure creates before applying new updated state.
for _, ep := range changes.Create {
if err := p.api.createRecord(ctx, ep); err != nil {
return err
}
}
for _, ep := range changes.UpdateNew {
for _, ep := range updateNew {
if err := p.api.createRecord(ctx, ep); err != nil {
return err
}
Expand Down
50 changes: 49 additions & 1 deletion provider/pihole/pihole_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

type testPiholeClient struct {
endpoints []*endpoint.Endpoint
requests *requestTracker
}

func (t *testPiholeClient) listRecords(ctx context.Context, rtype string) ([]*endpoint.Endpoint, error) {
Expand All @@ -40,6 +41,7 @@ func (t *testPiholeClient) listRecords(ctx context.Context, rtype string) ([]*en

func (t *testPiholeClient) createRecord(ctx context.Context, ep *endpoint.Endpoint) error {
t.endpoints = append(t.endpoints, ep)
t.requests.createRequests += 1
return nil
}

Expand All @@ -51,9 +53,20 @@ func (t *testPiholeClient) deleteRecord(ctx context.Context, ep *endpoint.Endpoi
}
}
t.endpoints = newEPs
t.requests.deleteRequests += 1
return nil
}

type requestTracker struct {
createRequests int
deleteRequests int
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better unit tests would explicitly specify the expected requests, verifying the dnsname, recordtype, and target.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the tests to compare the actual and expected requests instead of just the counts

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new version of the tests is asserting that the three new records are created in order, which isn't actually a requirement.

I would have written this to put the expected calls in the tracker. The client would then find the request in the tracker and remove it, asserting that the request was found in the tracker. At the end, the test would assert that there were no remaining expected calls in the tracker.

}

func (r *requestTracker) clear() {
r.createRequests = 0
r.deleteRequests = 0
}

func TestNewPiholeProvider(t *testing.T) {
// Test invalid configuration
_, err := NewPiholeProvider(PiholeConfig{})
Expand All @@ -68,8 +81,9 @@ func TestNewPiholeProvider(t *testing.T) {
}

func TestProvider(t *testing.T) {
requests := requestTracker{}
p := &PiholeProvider{
api: &testPiholeClient{},
api: &testPiholeClient{endpoints: make([]*endpoint.Endpoint, 0), requests: &requests},
}

records, err := p.Records(context.Background())
Expand Down Expand Up @@ -113,6 +127,12 @@ func TestProvider(t *testing.T) {
if len(newRecords) != 3 {
t.Fatal("Expected list of 3 records, got:", records)
}
if requests.createRequests != 3 {
t.Fatal("Expected 3 create requests, got:", requests.createRequests)
}
if requests.deleteRequests != 0 {
t.Fatal("Expected no delete requests, got:", requests.deleteRequests)
}

for idx, record := range records {
if newRecords[idx].DNSName != record.DNSName {
Expand All @@ -123,6 +143,8 @@ func TestProvider(t *testing.T) {
}
}

requests.clear()

// Test delete a record

records = []*endpoint.Endpoint{
Expand All @@ -148,6 +170,12 @@ func TestProvider(t *testing.T) {
}); err != nil {
t.Fatal(err)
}
if requests.createRequests != 0 {
t.Fatal("Expected no create requests, got:", requests.createRequests)
}
if requests.deleteRequests != 1 {
t.Fatal("Expected 1 delete request, got:", requests.deleteRequests)
}

// Test records are updated
newRecords, err = p.Records(context.Background())
Expand All @@ -167,6 +195,8 @@ func TestProvider(t *testing.T) {
}
}

requests.clear()

// Test update a record

records = []*endpoint.Endpoint{
Expand All @@ -183,13 +213,23 @@ func TestProvider(t *testing.T) {
}
if err := p.ApplyChanges(context.Background(), &plan.Changes{
UpdateOld: []*endpoint.Endpoint{
{
DNSName: "test1.example.com",
Targets: []string{"192.168.1.1"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test2.example.com",
Targets: []string{"192.168.1.2"},
RecordType: endpoint.RecordTypeA,
},
},
UpdateNew: []*endpoint.Endpoint{
{
DNSName: "test1.example.com",
Targets: []string{"192.168.1.1"},
RecordType: endpoint.RecordTypeA,
},
{
DNSName: "test2.example.com",
Targets: []string{"10.0.0.1"},
Expand All @@ -208,6 +248,12 @@ func TestProvider(t *testing.T) {
if len(newRecords) != 2 {
t.Fatal("Expected list of 2 records, got:", records)
}
if requests.createRequests != 1 {
t.Fatal("Expected 1 create request, got:", requests.createRequests)
}
if requests.deleteRequests != 1 {
t.Fatal("Expected 1 delete request, got:", requests.deleteRequests)
}

for idx, record := range records {
if newRecords[idx].DNSName != record.DNSName {
Expand All @@ -217,4 +263,6 @@ func TestProvider(t *testing.T) {
t.Error("Targets malformed on retrieval, got:", newRecords[idx].Targets, "expected:", record.Targets)
}
}

requests.clear()
}