Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Webhook provider #3063

Merged
merged 23 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/tutorials/webhook-provider.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ The following table represents the methods to implement mapped to their HTTP met
| AdjustEndpoints | POST | /adjustendpoints |
| ApplyChanges | POST | /records |

ExternalDNS will also make requests to the `/` endpoint for negotatiation and for deseliarization of the `DomainFilter`.
ExternalDNS will also make requests to the `/` endpoint for negotiation and for deserialization of the `DomainFilter`.

The server needs to respond to those requests by reading the `Accept` header and responding with a corresponding `Content-Type` header specifying the supported media type format and version.

Expand Down
129 changes: 124 additions & 5 deletions provider/webhook/httpapi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
Expand All @@ -31,34 +32,84 @@ import (
"sigs.k8s.io/external-dns/plan"
)

type FakeWebhookProvider struct{}
var records []*endpoint.Endpoint

type FakeWebhookProvider struct {
err error
domainFilter endpoint.DomainFilter
}

func (p FakeWebhookProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
return []*endpoint.Endpoint{}, nil
if p.err != nil {
return nil, p.err
}
return records, nil
}

func (p FakeWebhookProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
if p.err != nil {
return p.err
}
records = append(records, changes.Create...)
return nil
}

func (p FakeWebhookProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) {
// for simplicity, we do not adjust endpoints in this test
if p.err != nil {
return nil, p.err
}
return endpoints, nil
}

func (p FakeWebhookProvider) GetDomainFilter() endpoint.DomainFilter {
return endpoint.DomainFilter{}
return p.domainFilter
}

func TestMain(m *testing.M) {
records = []*endpoint.Endpoint{
{
DNSName: "foo.bar.com",
RecordType: "A",
},
}
m.Run()
}

func TestRecordsHandlerRecords(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/records", nil)
w := httptest.NewRecorder()

providerAPIServer := &WebhookServer{
provider: &FakeWebhookProvider{},
provider: &FakeWebhookProvider{
domainFilter: endpoint.NewDomainFilter([]string{"foo.bar.com"}),
},
}
providerAPIServer.recordsHandler(w, req)
res := w.Result()
require.Equal(t, http.StatusOK, res.StatusCode)
// require that the res has the same endpoints as the records slice
defer res.Body.Close()
require.NotNil(t, res.Body)
endpoints := []*endpoint.Endpoint{}
if err := json.NewDecoder(res.Body).Decode(&endpoints); err != nil {
t.Errorf("Failed to decode response body: %s", err.Error())
}
require.Equal(t, records, endpoints)
}

func TestRecordsHandlerRecordsWithErrors(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/records", nil)
w := httptest.NewRecorder()

providerAPIServer := &WebhookServer{
provider: &FakeWebhookProvider{
err: fmt.Errorf("error"),
},
}
providerAPIServer.recordsHandler(w, req)
res := w.Result()
require.Equal(t, http.StatusInternalServerError, res.StatusCode)
}

func TestRecordsHandlerApplyChangesWithBadRequest(t *testing.T) {
Expand Down Expand Up @@ -99,6 +150,46 @@ func TestRecordsHandlerApplyChangesWithValidRequest(t *testing.T) {
require.Equal(t, http.StatusNoContent, res.StatusCode)
}

func TestRecordsHandlerApplyChangesWithErrors(t *testing.T) {
changes := &plan.Changes{
Create: []*endpoint.Endpoint{
{
DNSName: "foo.bar.com",
RecordType: "A",
Targets: endpoint.Targets{},
},
},
}
j, err := json.Marshal(changes)
require.NoError(t, err)

reader := bytes.NewReader(j)

req := httptest.NewRequest(http.MethodPost, "/applychanges", reader)
w := httptest.NewRecorder()

providerAPIServer := &WebhookServer{
provider: &FakeWebhookProvider{
err: fmt.Errorf("error"),
},
}
providerAPIServer.recordsHandler(w, req)
res := w.Result()
require.Equal(t, http.StatusInternalServerError, res.StatusCode)
}

func TestRecordsHandlerWithWrongHTTPMethod(t *testing.T) {
req := httptest.NewRequest(http.MethodPut, "/records", nil)
w := httptest.NewRecorder()

providerAPIServer := &WebhookServer{
provider: &FakeWebhookProvider{},
}
providerAPIServer.recordsHandler(w, req)
res := w.Result()
require.Equal(t, http.StatusBadRequest, res.StatusCode)
}

func TestAdjustEndpointsHandlerWithInvalidRequest(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/adjustendpoints", nil)
w := httptest.NewRecorder()
Expand All @@ -117,7 +208,7 @@ func TestAdjustEndpointsHandlerWithInvalidRequest(t *testing.T) {
require.Equal(t, http.StatusBadRequest, res.StatusCode)
}

func TestAdjustEndpointsWithValidRequest(t *testing.T) {
func TestAdjustEndpointsHandlerWithValidRequest(t *testing.T) {
pve := []*endpoint.Endpoint{
{
DNSName: "foo.bar.com",
Expand All @@ -143,6 +234,34 @@ func TestAdjustEndpointsWithValidRequest(t *testing.T) {
require.NotNil(t, res.Body)
}

func TestAdjustEndpointsHandlerWithError(t *testing.T) {
pve := []*endpoint.Endpoint{
{
DNSName: "foo.bar.com",
RecordType: "A",
Targets: endpoint.Targets{},
RecordTTL: 0,
},
}

j, err := json.Marshal(pve)
require.NoError(t, err)

reader := bytes.NewReader(j)
req := httptest.NewRequest(http.MethodPost, "/adjustendpoints", reader)
w := httptest.NewRecorder()

providerAPIServer := &WebhookServer{
provider: &FakeWebhookProvider{
err: fmt.Errorf("error"),
},
}
providerAPIServer.adjustEndpointsHandler(w, req)
res := w.Result()
require.Equal(t, http.StatusInternalServerError, res.StatusCode)
require.NotNil(t, res.Body)
}

func TestStartHTTPApi(t *testing.T) {
startedChan := make(chan struct{})
go StartHTTPApi(FakeWebhookProvider{}, startedChan, 5*time.Second, 10*time.Second, "127.0.0.1:8887")
Expand Down
2 changes: 1 addition & 1 deletion provider/webhook/webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ func (p WebhookProvider) AdjustEndpoints(e []*endpoint.Endpoint) ([]*endpoint.En
if resp.StatusCode != http.StatusOK {
adjustEndpointsErrorsGauge.Inc()
log.Debugf("Failed to AdjustEndpoints with code %d", resp.StatusCode)
return nil, err
return nil, fmt.Errorf("failed to AdjustEndpoints with code %d", resp.StatusCode)
Raffo marked this conversation as resolved.
Show resolved Hide resolved
}

if err := json.NewDecoder(resp.Body).Decode(&endpoints); err != nil {
Expand Down
56 changes: 54 additions & 2 deletions provider/webhook/webhook_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,17 +46,20 @@ func TestInvalidDomainFilter(t *testing.T) {
}

func TestValidDomainfilter(t *testing.T) {
// initialize domanin filter
Raffo marked this conversation as resolved.
Show resolved Hide resolved
domainFilter := endpoint.NewDomainFilter([]string{"example.com"})
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
w.Header().Set(contentTypeHeader, mediaTypeFormatAndVersion)
w.Write([]byte(`{}`))
json.NewEncoder(w).Encode(domainFilter)
return
}
}))
defer svr.Close()

_, err := NewWebhookProvider(svr.URL)
p, err := NewWebhookProvider(svr.URL)
require.NoError(t, err)
Raffo marked this conversation as resolved.
Show resolved Hide resolved
require.Equal(t, p.GetDomainFilter(), endpoint.NewDomainFilter([]string{"example.com"}))
}

func TestRecords(t *testing.T) {
Expand All @@ -66,6 +69,7 @@ func TestRecords(t *testing.T) {
w.Write([]byte(`{}`))
return
}
require.Equal(t, "/records", r.URL.Path)
w.Write([]byte(`[{
Raffo marked this conversation as resolved.
Show resolved Hide resolved
"dnsName" : "test.example.com"
}]`))
Expand All @@ -82,6 +86,24 @@ func TestRecords(t *testing.T) {
}}, endpoints)
}

Raffo marked this conversation as resolved.
Show resolved Hide resolved
func TestRecordsWithErrors(t *testing.T) {
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
w.Header().Set(contentTypeHeader, mediaTypeFormatAndVersion)
w.Write([]byte(`{}`))
return
}
require.Equal(t, "/records", r.URL.Path)
w.WriteHeader(http.StatusInternalServerError)
}))
defer svr.Close()

provider, err := NewWebhookProvider(svr.URL)
require.NoError(t, err)
_, err = provider.Records(context.Background())
require.NotNil(t, err)
}

func TestApplyChanges(t *testing.T) {
successfulApplyChanges := true
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -90,6 +112,7 @@ func TestApplyChanges(t *testing.T) {
w.Write([]byte(`{}`))
return
}
require.Equal(t, "/records", r.URL.Path)
if successfulApplyChanges {
w.WriteHeader(http.StatusNoContent)
} else {
Expand All @@ -116,6 +139,8 @@ func TestAdjustEndpoints(t *testing.T) {
w.Write([]byte(`{}`))
return
}
require.Equal(t, "/adjustendpoints", r.URL.Path)

var endpoints []*endpoint.Endpoint
defer r.Body.Close()
b, err := io.ReadAll(r.Body)
Expand Down Expand Up @@ -158,5 +183,32 @@ func TestAdjustEndpoints(t *testing.T) {
"",
},
}}, adjustedEndpoints)
}

func TestAdjustendpointsWithError(t *testing.T) {
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/" {
w.Header().Set(contentTypeHeader, mediaTypeFormatAndVersion)
w.Write([]byte(`{}`))
return
}
require.Equal(t, "/adjustendpoints", r.URL.Path)
w.WriteHeader(http.StatusInternalServerError)
}))
defer svr.Close()

provider, err := NewWebhookProvider(svr.URL)
require.NoError(t, err)
endpoints := []*endpoint.Endpoint{
{
DNSName: "test.example.com",
RecordTTL: 10,
RecordType: "A",
Targets: endpoint.Targets{
"",
},
},
}
_, err = provider.AdjustEndpoints(endpoints)
require.Error(t, err)
}
Loading