Skip to content

Commit

Permalink
fix(infoblox): set view and zone query parameters
Browse files Browse the repository at this point in the history
The `zone` and `view` search query parameters are not included in record search requests. This can result in more records returned in the query than necessary.

For example if more than one zone is returned, rather
than issuing a query on each respective zone it is
searching all zones on each iteration.
  • Loading branch information
cronik committed May 4, 2023
1 parent 510eb95 commit 0f2d419
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 42 deletions.
52 changes: 18 additions & 34 deletions provider/infoblox/infoblox.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,18 @@ func NewInfobloxProvider(ibStartupCfg StartupConfig) (*ProviderConfig, error) {
return providerCfg, nil
}

func recordQueryParams(zone string, view string) *ibclient.QueryParams {
searchFields := map[string]string{}
if zone != "" {
searchFields["zone"] = zone
}

if view != "" {
searchFields["view"] = view
}
return ibclient.NewQueryParams(false, searchFields)
}

// Records gets the current records.
func (p *ProviderConfig) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, err error) {
zones, err := p.zones()
Expand All @@ -192,23 +204,9 @@ func (p *ProviderConfig) Records(ctx context.Context) (endpoints []*endpoint.End

for _, zone := range zones {
logrus.Debugf("fetch records from zone '%s'", zone.Fqdn)

view := p.view
if view == "" {
view = "default"
}
searchParams := ibclient.NewQueryParams(
false,
map[string]string{
"zone": zone.Fqdn,
"view": view,
},
)

searchParams := recordQueryParams(zone.Fqdn, p.view)
var resA []ibclient.RecordA
objA := ibclient.NewEmptyRecordA()
objA.View = p.view
objA.Zone = zone.Fqdn
err = p.client.GetObject(objA, "", searchParams, &resA)
if err != nil && !isNotFoundError(err) {
return nil, fmt.Errorf("could not fetch A records from zone '%s': %s", zone.Fqdn, err)
Expand Down Expand Up @@ -253,8 +251,6 @@ func (p *ProviderConfig) Records(ctx context.Context) (endpoints []*endpoint.End
// Include Host records since they should be treated synonymously with A records
var resH []ibclient.HostRecord
objH := ibclient.NewEmptyHostRecord()
objH.View = p.view
objH.Zone = zone.Fqdn
err = p.client.GetObject(objH, "", searchParams, &resH)
if err != nil && !isNotFoundError(err) {
return nil, fmt.Errorf("could not fetch host records from zone '%s': %s", zone.Fqdn, err)
Expand All @@ -275,8 +271,6 @@ func (p *ProviderConfig) Records(ctx context.Context) (endpoints []*endpoint.End

var resC []ibclient.RecordCNAME
objC := ibclient.NewEmptyRecordCNAME()
objC.View = p.view
objC.Zone = zone.Fqdn
err = p.client.GetObject(objC, "", searchParams, &resC)
if err != nil && !isNotFoundError(err) {
return nil, fmt.Errorf("could not fetch CNAME records from zone '%s': %s", zone.Fqdn, err)
Expand All @@ -294,9 +288,7 @@ func (p *ProviderConfig) Records(ctx context.Context) (endpoints []*endpoint.End
if err == nil {
var resP []ibclient.RecordPTR
objP := ibclient.NewEmptyRecordPTR()
objP.Zone = arpaZone
objP.View = p.view
err = p.client.GetObject(objP, "", searchParams, &resP)
err = p.client.GetObject(objP, "", recordQueryParams(arpaZone, p.view), &resP)
if err != nil && !isNotFoundError(err) {
return nil, fmt.Errorf("could not fetch PTR records from zone '%s': %s", zone.Fqdn, err)
}
Expand All @@ -307,12 +299,7 @@ func (p *ProviderConfig) Records(ctx context.Context) (endpoints []*endpoint.End
}

var resT []ibclient.RecordTXT
objT := ibclient.NewRecordTXT(
ibclient.RecordTXT{
Zone: zone.Fqdn,
View: p.view,
},
)
objT := ibclient.NewRecordTXT(ibclient.RecordTXT{})
err = p.client.GetObject(objT, "", searchParams, &resT)
if err != nil && !isNotFoundError(err) {
return nil, fmt.Errorf("could not fetch TXT records from zone '%s': %s", zone.Fqdn, err)
Expand Down Expand Up @@ -434,12 +421,9 @@ func (p *ProviderConfig) ApplyChanges(ctx context.Context, changes *plan.Changes

func (p *ProviderConfig) zones() ([]ibclient.ZoneAuth, error) {
var res, result []ibclient.ZoneAuth
obj := ibclient.NewZoneAuth(
ibclient.ZoneAuth{
View: p.view,
},
)
err := p.client.GetObject(obj, "", nil, &res)
obj := ibclient.NewZoneAuth(ibclient.ZoneAuth{})
queryParams := recordQueryParams("", p.view)
err := p.client.GetObject(obj, "", queryParams, &res)
if err != nil && !isNotFoundError(err) {
return nil, err
}
Expand Down
157 changes: 149 additions & 8 deletions provider/infoblox/infoblox_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@ limitations under the License.
package infoblox

import (
"bytes"
"context"
"encoding/base64"
"fmt"
"net/http"
"net/url"
"regexp"
"strings"
"testing"
Expand All @@ -41,6 +43,68 @@ type mockIBConnector struct {
createdEndpoints []*endpoint.Endpoint
deletedEndpoints []*endpoint.Endpoint
updatedEndpoints []*endpoint.Endpoint
getObjectRequests []*getObjectRequest
requestBuilder ExtendedRequestBuilder
}

type getObjectRequest struct {
obj string
ref string
queryParams string
url url.URL
verified bool
}

func (req *getObjectRequest) ExpectRequestURLQueryParam(t *testing.T, name string, value string) *getObjectRequest {
if req.url.Query().Get(name) != value {
t.Errorf("Expected GetObject Request URL to contain query parameter %s=%s, Got: %v", name, value, req.url.Query())
}

return req
}

func (req *getObjectRequest) ExpectNotRequestURLQueryParam(t *testing.T, name string) *getObjectRequest {
if req.url.Query().Has(name) {
t.Errorf("Expected GetObject Request URL not to contain query parameter %s, Got: %v", name, req.url.Query())
}

return req
}

func (client *mockIBConnector) verifyGetObjectRequest(t *testing.T, obj string, ref string, query *map[string]string) *getObjectRequest {
qp := ""
if query != nil {
qp = fmt.Sprint(ibclient.NewQueryParams(false, *query))
}

for _, req := range client.getObjectRequests {
if !req.verified && req.obj == obj && req.ref == ref && req.queryParams == qp {
req.verified = true
return req
}
}

t.Errorf("Expected GetObject obj=%s, query=%s, ref=%s", obj, qp, ref)
return &getObjectRequest{}
}

// verifyNoMoreGetObjectRequests will assert that all "GetObject" calls have been verified.
func (client *mockIBConnector) verifyNoMoreGetObjectRequests(t *testing.T) {
unverified := []getObjectRequest{}
for _, req := range client.getObjectRequests {
if !req.verified {
unverified = append(unverified, *req)
}
}

if len(unverified) > 0 {
b := new(bytes.Buffer)
for _, req := range unverified {
fmt.Fprintf(b, "obj=%s, ref=%s, params=%s (url=%s)\n", req.obj, req.ref, req.queryParams, req.url.String())
}

t.Errorf("Unverified GetObject Requests: %v", unverified)
}
}

func (client *mockIBConnector) CreateObject(obj ibclient.IBObject) (ref string, err error) {
Expand Down Expand Up @@ -115,6 +179,18 @@ func (client *mockIBConnector) CreateObject(obj ibclient.IBObject) (ref string,
}

func (client *mockIBConnector) GetObject(obj ibclient.IBObject, ref string, queryParams *ibclient.QueryParams, res interface{}) (err error) {
req := getObjectRequest{
obj: obj.ObjectType(),
ref: ref,
}
if queryParams != nil {
req.queryParams = fmt.Sprint(queryParams)
}
r, _ := client.requestBuilder.BuildRequest(ibclient.GET, obj, ref, queryParams)
if r != nil {
req.url = *r.URL
}
client.getObjectRequests = append(client.getObjectRequests, &req)
switch obj.ObjectType() {
case "record:a":
var result []ibclient.RecordA
Expand Down Expand Up @@ -383,13 +459,14 @@ func createMockInfobloxObject(name, recordType, value string) ibclient.IBObject
return nil
}

func newInfobloxProvider(domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, dryRun bool, createPTR bool, client ibclient.IBConnector) *ProviderConfig {
func newInfobloxProvider(domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, view string, dryRun bool, createPTR bool, client ibclient.IBConnector) *ProviderConfig {
return &ProviderConfig{
client: client,
domainFilter: domainFilter,
zoneIDFilter: zoneIDFilter,
dryRun: dryRun,
createPTR: createPTR,
view: view,
}
}

Expand Down Expand Up @@ -417,7 +494,7 @@ func TestInfobloxRecords(t *testing.T) {
},
}

providerCfg := newInfobloxProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, false, &client)
providerCfg := newInfobloxProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), "", true, false, &client)
actual, err := providerCfg.Records(context.Background())
if err != nil {
t.Fatal(err)
Expand All @@ -437,6 +514,70 @@ func TestInfobloxRecords(t *testing.T) {
endpoint.NewEndpoint("host.example.com", endpoint.RecordTypeA, "125.1.1.1"),
}
validateEndpoints(t, actual, expected)
client.verifyGetObjectRequest(t, "zone_auth", "", &map[string]string{}).
ExpectNotRequestURLQueryParam(t, "view").
ExpectNotRequestURLQueryParam(t, "zone")
client.verifyGetObjectRequest(t, "record:a", "", &map[string]string{"zone": "example.com"}).
ExpectRequestURLQueryParam(t, "zone", "example.com")
client.verifyGetObjectRequest(t, "record:host", "", &map[string]string{"zone": "example.com"}).
ExpectRequestURLQueryParam(t, "zone", "example.com")
client.verifyGetObjectRequest(t, "record:cname", "", &map[string]string{"zone": "example.com"}).
ExpectRequestURLQueryParam(t, "zone", "example.com")
client.verifyGetObjectRequest(t, "record:txt", "", &map[string]string{"zone": "example.com"}).
ExpectRequestURLQueryParam(t, "zone", "example.com")
client.verifyNoMoreGetObjectRequests(t)
}

func TestInfobloxRecordsWithView(t *testing.T) {
client := mockIBConnector{
mockInfobloxZones: &[]ibclient.ZoneAuth{
createMockInfobloxZone("foo.example.com"),
createMockInfobloxZone("bar.example.com"),
},
mockInfobloxObjects: &[]ibclient.IBObject{
createMockInfobloxObject("cat.foo.example.com", endpoint.RecordTypeA, "123.123.123.122"),
createMockInfobloxObject("dog.bar.example.com", endpoint.RecordTypeA, "123.123.123.123"),
},
}

providerCfg := newInfobloxProvider(endpoint.NewDomainFilter([]string{"foo.example.com", "bar.example.com"}), provider.NewZoneIDFilter([]string{""}), "Inside", true, false, &client)
actual, err := providerCfg.Records(context.Background())
if err != nil {
t.Fatal(err)
}
expected := []*endpoint.Endpoint{
endpoint.NewEndpoint("cat.foo.example.com", endpoint.RecordTypeA, "123.123.123.122"),
endpoint.NewEndpoint("dog.bar.example.com", endpoint.RecordTypeA, "123.123.123.123"),
}
validateEndpoints(t, actual, expected)
client.verifyGetObjectRequest(t, "zone_auth", "", &map[string]string{"view": "Inside"}).
ExpectRequestURLQueryParam(t, "view", "Inside").
ExpectNotRequestURLQueryParam(t, "zone")
client.verifyGetObjectRequest(t, "record:a", "", &map[string]string{"zone": "foo.example.com", "view": "Inside"}).
ExpectRequestURLQueryParam(t, "zone", "foo.example.com").
ExpectRequestURLQueryParam(t, "view", "Inside")
client.verifyGetObjectRequest(t, "record:host", "", &map[string]string{"zone": "foo.example.com", "view": "Inside"}).
ExpectRequestURLQueryParam(t, "zone", "foo.example.com").
ExpectRequestURLQueryParam(t, "view", "Inside")
client.verifyGetObjectRequest(t, "record:cname", "", &map[string]string{"zone": "foo.example.com", "view": "Inside"}).
ExpectRequestURLQueryParam(t, "zone", "foo.example.com").
ExpectRequestURLQueryParam(t, "view", "Inside")
client.verifyGetObjectRequest(t, "record:txt", "", &map[string]string{"zone": "foo.example.com", "view": "Inside"}).
ExpectRequestURLQueryParam(t, "zone", "foo.example.com").
ExpectRequestURLQueryParam(t, "view", "Inside")
client.verifyGetObjectRequest(t, "record:a", "", &map[string]string{"zone": "bar.example.com", "view": "Inside"}).
ExpectRequestURLQueryParam(t, "zone", "bar.example.com").
ExpectRequestURLQueryParam(t, "view", "Inside")
client.verifyGetObjectRequest(t, "record:host", "", &map[string]string{"zone": "bar.example.com", "view": "Inside"}).
ExpectRequestURLQueryParam(t, "zone", "bar.example.com").
ExpectRequestURLQueryParam(t, "view", "Inside")
client.verifyGetObjectRequest(t, "record:cname", "", &map[string]string{"zone": "bar.example.com", "view": "Inside"}).
ExpectRequestURLQueryParam(t, "zone", "bar.example.com").
ExpectRequestURLQueryParam(t, "view", "Inside")
client.verifyGetObjectRequest(t, "record:txt", "", &map[string]string{"zone": "bar.example.com", "view": "Inside"}).
ExpectRequestURLQueryParam(t, "zone", "bar.example.com").
ExpectRequestURLQueryParam(t, "view", "Inside")
client.verifyNoMoreGetObjectRequests(t)
}

func TestInfobloxAdjustEndpoints(t *testing.T) {
Expand All @@ -453,7 +594,7 @@ func TestInfobloxAdjustEndpoints(t *testing.T) {
},
}

providerCfg := newInfobloxProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), true, true, &client)
providerCfg := newInfobloxProvider(endpoint.NewDomainFilter([]string{"example.com"}), provider.NewZoneIDFilter([]string{""}), "", true, true, &client)
actual, err := providerCfg.Records(context.Background())
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -481,7 +622,7 @@ func TestInfobloxRecordsReverse(t *testing.T) {
},
}

providerCfg := newInfobloxProvider(endpoint.NewDomainFilter([]string{"10.0.0.0/24"}), provider.NewZoneIDFilter([]string{""}), true, true, &client)
providerCfg := newInfobloxProvider(endpoint.NewDomainFilter([]string{"10.0.0.0/24"}), provider.NewZoneIDFilter([]string{""}), "", true, true, &client)
actual, err := providerCfg.Records(context.Background())
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -588,6 +729,7 @@ func testInfobloxApplyChangesInternal(t *testing.T, dryRun, createPTR bool, clie
providerCfg := newInfobloxProvider(
endpoint.NewDomainFilter([]string{""}),
provider.NewZoneIDFilter([]string{""}),
"",
dryRun,
createPTR,
client,
Expand Down Expand Up @@ -653,7 +795,7 @@ func TestInfobloxZones(t *testing.T) {
mockInfobloxObjects: &[]ibclient.IBObject{},
}

providerCfg := newInfobloxProvider(endpoint.NewDomainFilter([]string{"example.com", "1.2.3.0/24"}), provider.NewZoneIDFilter([]string{""}), true, false, &client)
providerCfg := newInfobloxProvider(endpoint.NewDomainFilter([]string{"example.com", "1.2.3.0/24"}), provider.NewZoneIDFilter([]string{""}), "", true, false, &client)
zones, _ := providerCfg.zones()
var emptyZoneAuth *ibclient.ZoneAuth
assert.Equal(t, providerCfg.findZone(zones, "example.com").Fqdn, "example.com")
Expand All @@ -677,7 +819,7 @@ func TestInfobloxReverseZones(t *testing.T) {
mockInfobloxObjects: &[]ibclient.IBObject{},
}

providerCfg := newInfobloxProvider(endpoint.NewDomainFilter([]string{"example.com", "1.2.3.0/24", "10.0.0.0/8"}), provider.NewZoneIDFilter([]string{""}), true, false, &client)
providerCfg := newInfobloxProvider(endpoint.NewDomainFilter([]string{"example.com", "1.2.3.0/24", "10.0.0.0/8"}), provider.NewZoneIDFilter([]string{""}), "", true, false, &client)
zones, _ := providerCfg.zones()
var emptyZoneAuth *ibclient.ZoneAuth
assert.Equal(t, providerCfg.findReverseZone(zones, "nomatch-example.com"), emptyZoneAuth)
Expand Down Expand Up @@ -738,7 +880,6 @@ func TestExtendedRequestNameRegExBuilder(t *testing.T) {
assert.True(t, req.URL.Query().Get("name~") == "")
}


func TestExtendedRequestMaxResultsBuilder(t *testing.T) {
hostCfg := ibclient.HostConfig{
Host: "localhost",
Expand Down Expand Up @@ -774,7 +915,7 @@ func TestGetObject(t *testing.T) {
requestor := mockRequestor{}
client, _ := ibclient.NewConnector(hostCfg, authCfg, transportConfig, requestBuilder, &requestor)

providerConfig := newInfobloxProvider(endpoint.NewDomainFilter([]string{"mysite.com"}), provider.NewZoneIDFilter([]string{""}), true, true, client)
providerConfig := newInfobloxProvider(endpoint.NewDomainFilter([]string{"mysite.com"}), provider.NewZoneIDFilter([]string{""}), "", true, true, client)

providerConfig.deleteRecords(infobloxChangeMap{
"myzone.com": []*endpoint.Endpoint{
Expand Down

0 comments on commit 0f2d419

Please sign in to comment.