diff --git a/example_test.go b/example_test.go index 09e767f..0d1e3ff 100644 --- a/example_test.go +++ b/example_test.go @@ -38,11 +38,11 @@ func ExampleBasicAuth() { // nolint: govet // Create the listener. r := webhook.Registration{ - Address: server.URL, // Replace this with a real address. Duration: webhook.CustomDuration(5 * time.Minute), } - whl, err := listener.New(&r, + url := server.URL // replace with the URL of the webhook provider + whl, err := listener.New(&r, url, listener.AuthBasic("username", "password"), listener.AcceptSHA1(), listener.AcceptedSecrets("foobar", "carport"), diff --git a/functional_test.go b/functional_test.go index 3e72f35..a63638a 100644 --- a/functional_test.go +++ b/functional_test.go @@ -56,16 +56,17 @@ func TestNormalUsage(t *testing.T) { defer server.Close() // Create the listener. - whl, err := New(&webhook.Registration{ - Address: server.URL, - Events: []string{ - "foo", - }, - Config: webhook.DeliveryConfig{ - Secret: "secret1", + whl, err := New( + &webhook.Registration{ + Events: []string{ + "foo", + }, + Config: webhook.DeliveryConfig{ + Secret: "secret1", + }, + Duration: webhook.CustomDuration(5 * time.Minute), }, - Duration: webhook.CustomDuration(5 * time.Minute), - }, + server.URL, Interval(1*time.Millisecond), AuthBasic("user", "pass"), ) @@ -148,16 +149,17 @@ func TestSingleShotUsage(t *testing.T) { defer server.Close() // Create the listener. - whl, err := New(&webhook.Registration{ - Address: server.URL, - Events: []string{ - "foo", - }, - Config: webhook.DeliveryConfig{ - Secret: "secret1", + whl, err := New( + &webhook.Registration{ + Events: []string{ + "foo", + }, + Config: webhook.DeliveryConfig{ + Secret: "secret1", + }, + Duration: webhook.CustomDuration(5 * time.Minute), }, - Duration: webhook.CustomDuration(5 * time.Minute), - }, + server.URL, Once(), ) require.NotNil(whl) @@ -172,7 +174,7 @@ func TestSingleShotUsage(t *testing.T) { assert.NoError(err) // Wait a bit then roll the secret.. - time.Sleep(time.Millisecond) + time.Sleep(10 * time.Millisecond) m.Lock() expectSecret = append(expectSecret, "secret2", "secret3", "secret4") m.Unlock() @@ -187,13 +189,13 @@ func TestSingleShotUsage(t *testing.T) { assert.NoError(err) // Wait a bit then remove the prior secret from the list of accepted secrets. - time.Sleep(time.Millisecond) + time.Sleep(10 * time.Millisecond) m.Lock() expectSecret = []string{"secret5"} m.Unlock() // Wait a bit then unregister. - time.Sleep(time.Millisecond) + time.Sleep(10 * time.Millisecond) whl.Stop() // Re-stop because it could happen. @@ -214,16 +216,17 @@ func TestFailedHTTPCall(t *testing.T) { defer server.Close() // Create the listener. - whl, err := New(&webhook.Registration{ - Address: server.URL, - Events: []string{ - "foo", - }, - Config: webhook.DeliveryConfig{ - Secret: "secret1", + whl, err := New( + &webhook.Registration{ + Events: []string{ + "foo", + }, + Config: webhook.DeliveryConfig{ + Secret: "secret1", + }, + Duration: webhook.CustomDuration(5 * time.Minute), }, - Duration: webhook.CustomDuration(5 * time.Minute), - }, + server.URL, Once(), ) @@ -240,15 +243,17 @@ func TestFailedAuthCheck(t *testing.T) { require := require.New(t) // Create the listener. - whl, err := New(&webhook.Registration{ - Events: []string{ - "foo", - }, - Config: webhook.DeliveryConfig{ - Secret: "secret1", + whl, err := New( + &webhook.Registration{ + Events: []string{ + "foo", + }, + Config: webhook.DeliveryConfig{ + Secret: "secret1", + }, + Duration: webhook.CustomDuration(5 * time.Minute), }, - Duration: webhook.CustomDuration(5 * time.Minute), - }, + "http://example.com", AuthBearerFunc(func() (string, error) { return "", fmt.Errorf("nope") }), @@ -267,16 +272,18 @@ func TestFailedNewRequest(t *testing.T) { require := require.New(t) // Create the listener. - whl, err := New(&webhook.Registration{ - Address: "//invalid::localhost/:99999", - Events: []string{ - "foo", - }, - Config: webhook.DeliveryConfig{ - Secret: "secret1", + whl, err := New( + &webhook.Registration{ + Events: []string{ + "foo", + }, + Config: webhook.DeliveryConfig{ + Secret: "secret1", + }, + Duration: webhook.CustomDuration(5 * time.Minute), }, - Duration: webhook.CustomDuration(5 * time.Minute), - }) + "//invalid::localhost/:99999", + ) require.NotNil(whl) require.NoError(err) @@ -300,16 +307,17 @@ func TestFailedConnect(t *testing.T) { defer server.Close() // Create the listener. - whl, err := New(&webhook.Registration{ - Address: server.URL, - Events: []string{ - "foo", - }, - Config: webhook.DeliveryConfig{ - Secret: "secret1", + whl, err := New( + &webhook.Registration{ + Events: []string{ + "foo", + }, + Config: webhook.DeliveryConfig{ + Secret: "secret1", + }, + Duration: webhook.CustomDuration(5 * time.Minute), }, - Duration: webhook.CustomDuration(5 * time.Minute), - }, + server.URL, HTTPClient(&http.Client{Timeout: 1 * time.Millisecond}), Once(), ) @@ -321,3 +329,57 @@ func TestFailedConnect(t *testing.T) { err = whl.Register() assert.ErrorIs(err, ErrRegistrationFailed) } + +func TestFailsAfterABit(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + var m sync.Mutex + var count int + + server := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + r.Body.Close() + + m.Lock() + if count == 0 { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusBadRequest) + } + count++ + m.Unlock() + }, + ), + ) + defer server.Close() + + // Create the listener. + whl, err := New( + &webhook.Registration{ + Events: []string{ + "foo", + }, + Config: webhook.DeliveryConfig{ + Secret: "secret1", + }, + Duration: webhook.CustomDuration(5 * time.Minute), + }, + server.URL, + Interval(1*time.Millisecond), + AuthBasic("user", "pass"), + ) + require.NotNil(whl) + require.NoError(err) + + // Register the webhook before has started + err = whl.Register() + assert.NoError(err) + + // Wait a bit then roll the secret.. + time.Sleep(10 * time.Millisecond) + + whl.Stop() +} diff --git a/listener.go b/listener.go index 300ca7e..43c78d2 100644 --- a/listener.go +++ b/listener.go @@ -33,6 +33,7 @@ const ( type Listener struct { m sync.RWMutex registration *webhook.Registration + webhookURL string registrationOpts []webhook.Option interval time.Duration client *http.Client @@ -57,9 +58,19 @@ type Option interface { } // New creates a new webhook listener with the given registration and options. -func New(r *webhook.Registration, opts ...Option) (*Listener, error) { +func New(r *webhook.Registration, url string, opts ...Option) (*Listener, error) { + if r == nil { + return nil, fmt.Errorf("%w: registration is required", ErrInput) + } + + url = strings.TrimSpace(url) + if url == "" { + return nil, fmt.Errorf("%w: webhook url is required", ErrInput) + } + l := Listener{ registration: r, + webhookURL: url, registrationOpts: make([]webhook.Option, 0), logger: zap.NewNop(), client: http.DefaultClient, @@ -150,7 +161,14 @@ func (l *Listener) use(secret string) error { } if l.update != nil { - l.update <- struct{}{} + go func() { + defer func() { + // If the update channel is closed, the listener is shutting down. + // Ignore the panic. + _ = recover() + }() + l.update <- struct{}{} + }() } return nil @@ -183,6 +201,7 @@ func (l *Listener) run() { select { case <-l.ctx.Done(): l.m.Lock() + close(l.update) l.update = nil l.m.Unlock() return @@ -224,7 +243,7 @@ func (l *Listener) register(locked bool) error { } fn := l.getAuth - address := l.registration.Address + address := l.webhookURL body := l.body if !locked { @@ -273,7 +292,7 @@ func (l *Listener) register(locked bool) error { // Tokenize parses the token from the request header. If the token is not found // or is invalid, an error is returned. -func (l *Listener) Tokenize(r *http.Request) (*Token, error) { +func (l *Listener) Tokenize(r *http.Request) (*token, error) { headers := r.Header.Values(xmidtHeader) if len(headers) != 0 { l.metrics.TokenHeaderUsed.inc(xmidtHeader) @@ -325,12 +344,17 @@ func (l *Listener) Tokenize(r *http.Request) (*Token, error) { l.metrics.TokenAlgorithmUsed.inc(best) l.metrics.TokenOutcome.incValid() - return NewToken(best, choices[best]), nil + return newToken(best, choices[best]), nil } // Authorize validates that the request body matches the hash and secret provided // in the token. func (l *Listener) Authorize(r *http.Request, t Token) error { + if t == nil { + l.metrics.Authorization.incInvalidToken() + return fmt.Errorf("%w: invalid token", ErrInput) + } + secret, err := hex.DecodeString(t.Principal()) if err != nil { l.metrics.Authorization.incInvalidSignature() diff --git a/listener_test.go b/listener_test.go index f47874c..c074c31 100644 --- a/listener_test.go +++ b/listener_test.go @@ -18,13 +18,15 @@ import ( type vador func(*assert.Assertions, *Listener) type newTest struct { - description string - r webhook.Registration - opt Option - opts []Option - check vador - checks []vador - expectedErr error + description string + r webhook.Registration + noRegistration bool + noUrl bool + opt Option + opts []Option + check vador + checks []vador + expectedErr error } func vadorBody(assert *assert.Assertions, l *Listener) { @@ -62,6 +64,15 @@ func TestNew(t *testing.T) { { description: "empty is not ok", expectedErr: ErrInput, + }, { + description: "no url fails", + r: validWHR, + noUrl: true, + expectedErr: ErrInput, + }, { + description: "nil registration fails", + noRegistration: true, + expectedErr: ErrInput, }, { description: "nearly empty is ok", r: validWHR, @@ -92,12 +103,21 @@ func commonNewTest(t *testing.T, tests []newTest) { for _, tc := range tests { t.Run(tc.description, func(t *testing.T) { assert := assert.New(t) + require := require.New(t) r := tc.r opts := make([]Option, 0, len(tc.opts)+1) opts = append(opts, tc.opt) opts = append(opts, tc.opts...) - got, err := New(&r, opts...) + url := "http://example.com" + if tc.noUrl { + url = "" + } + rPtr := &r + if tc.noRegistration { + rPtr = nil + } + got, err := New(rPtr, url, opts...) if tc.expectedErr != nil { assert.Nil(got) @@ -105,6 +125,9 @@ func commonNewTest(t *testing.T, tests []newTest) { return } + require.NotNil(got) + require.NoError(err) + checks := make([]vador, 0, len(tc.checks)+1) checks = append(checks, tc.check) checks = append(checks, tc.checks...) @@ -134,7 +157,7 @@ func TestTokenize(t *testing.T) { }, }, opt: AcceptSHA1(), - expected: Token{ + expected: token{ alg: "sha1", principal: "12345", }, @@ -146,7 +169,7 @@ func TestTokenize(t *testing.T) { }, }, opt: AcceptSHA1(), - expected: Token{ + expected: token{ alg: "sha1", principal: "12345", }, @@ -161,14 +184,14 @@ func TestTokenize(t *testing.T) { AcceptSHA1(), AcceptNoHash(), }, - expected: Token{ + expected: token{ alg: "sha1", principal: "12345", }, }, { description: "no header with that name", opt: AcceptNoHash(), - expected: Token{ + expected: token{ alg: "none", principal: "", }, @@ -180,7 +203,7 @@ func TestTokenize(t *testing.T) { }, }, opt: AcceptNoHash(), - expected: Token{ + expected: token{ alg: "none", principal: "", }, @@ -236,6 +259,7 @@ func TestTokenize(t *testing.T) { whl, err := New(&webhook.Registration{ Duration: webhook.CustomDuration(5 * time.Minute), }, + "http://example.com", opts..., ) @@ -277,7 +301,7 @@ func TestAuthorize(t *testing.T) { AcceptSHA1(), AcceptedSecrets("123456"), }, - token: Token{ + token: token{ alg: "sha1", principal: "f76a55b14b2b3bd08116b4ee857dd6439b507317", }, @@ -290,7 +314,7 @@ func TestAuthorize(t *testing.T) { AcceptSHA1(), AcceptedSecrets("123456"), }, - token: Token{ + token: token{ alg: "sha1", principal: "0000", }, @@ -305,6 +329,10 @@ func TestAuthorize(t *testing.T) { AcceptedSecrets("123456"), }, expectedErr: ErrInput, + token: token{ + alg: "sha1", + principal: "0000", + }, }, { description: "no body", opts: []Option{ @@ -312,19 +340,26 @@ func TestAuthorize(t *testing.T) { AcceptedSecrets("123456"), }, expectedErr: ErrInput, + token: token{ + alg: "sha1", + principal: "0000", + }, }, { description: "invalid principle", - token: Token{ + token: token{ alg: "sha1", principal: "f", // invalid because it needs to be 2 characters. }, expectedErr: ErrInput, + }, { + description: "nil token", + expectedErr: ErrInput, }, { description: "no matching hash", input: http.Request{ Body: io.NopCloser(strings.NewReader("foo")), }, - token: Token{ + token: token{ alg: "sha1", principal: "f0", }, @@ -338,9 +373,11 @@ func TestAuthorize(t *testing.T) { require := require.New(t) opts := append(tc.opts, tc.opt) - whl, err := New(&webhook.Registration{ - Duration: webhook.CustomDuration(5 * time.Minute), - }, + whl, err := New( + &webhook.Registration{ + Duration: webhook.CustomDuration(5 * time.Minute), + }, + "http://example.com", opts..., ) @@ -386,7 +423,7 @@ func TestListener_Accept(t *testing.T) { r := validWHR opts := append(tc.opts, tc.opt) - l, err := New(&r, opts...) + l, err := New(&r, "http://example.com", opts...) require.NotNil(l) require.NoError(err) @@ -428,7 +465,7 @@ func TestListener_String(t *testing.T) { r := validWHR opts := append(tc.opts, tc.opt) - l, err := New(&r, opts...) + l, err := New(&r, "http://example.com", opts...) require.NotNil(l) require.NoError(err) diff --git a/metrics.go b/metrics.go index ee9aaf1..670e14e 100644 --- a/metrics.go +++ b/metrics.go @@ -208,6 +208,9 @@ type MeasureAuthorization struct { // Valid is the label used when the authorization is valid. Valid string `default:"success"` + // InvalidToken is the label used when the token is nil or invalid. + InvalidToken string `default:"failure_invalid_token"` + // InvalidSignature is the label used when the signature is invalid. InvalidSignature string `default:"failure_invalid_signature"` @@ -225,6 +228,10 @@ type MeasureAuthorization struct { Counter metrics.Counter `default:"discard.NewCounter()"` } +func (m *MeasureAuthorization) incInvalidToken() { + m.Counter.With(m.Label, m.InvalidToken).Add(1) +} + func (m *MeasureAuthorization) incInvalidSignature() { m.Counter.With(m.Label, m.InvalidSignature).Add(1) } diff --git a/metrics_test.go b/metrics_test.go index 7da8c36..801a77a 100644 --- a/metrics_test.go +++ b/metrics_test.go @@ -53,6 +53,7 @@ func TestMeasure_init(t *testing.T) { Authorization: MeasureAuthorization{ Label: "outcome", Valid: "success", + InvalidToken: "failure_invalid_token", InvalidSignature: "failure_invalid_signature", EmptyBody: "failure_empty_body", UnableToReadBody: "failure_unable_to_read_body", diff --git a/token.go b/token.go index 44f8809..c56cc72 100644 --- a/token.go +++ b/token.go @@ -5,25 +5,30 @@ package listener // Token represents the information needed to authenticate the flow of incoming // webhook callbacks. -type Token struct { +type Token interface { + Type() string + Principal() string +} + +type token struct { alg string principal string } // Type returns the type of hash to use for authentication. -func (t Token) Type() string { +func (t token) Type() string { return t.alg } // Principal returns the principal (calculated value included with the message) // to use for authentication. -func (t Token) Principal() string { +func (t token) Principal() string { return t.principal } -// NewToken creates a new token with the given hash type and principal. -func NewToken(alg, principal string) *Token { - return &Token{ +// newToken creates a new token with the given hash type and principal. +func newToken(alg, principal string) *token { + return &token{ alg: alg, principal: principal, }