From 0664688821a1762446ad7e5cdeca2de5ea8827bd Mon Sep 17 00:00:00 2001 From: Benjamin Bengfort Date: Wed, 3 Jan 2024 17:09:47 -0600 Subject: [PATCH 1/2] Backoff and Retry Client --- .github/workflows/test.yml | 2 +- LICENSE | 2 +- README.md | 19 +++++++- cmd/courier/main.go | 2 +- containers/courier/Dockerfile | 2 +- go.mod | 3 +- go.sum | 4 ++ pkg/api/v1/client.go | 74 +++++++++++++++++++++++++++--- pkg/api/v1/errors.go | 82 ++++++++++++++++++++++++++++------ pkg/api/v1/errors_test.go | 74 ++++++++++++++++++++++++++++++ pkg/api/v1/options.go | 65 +++++++++++++++++++++++++++ pkg/config/config.go | 22 ++++----- pkg/config/config_test.go | 38 ++++++++-------- pkg/secrets/client.go | 2 +- pkg/server.go | 4 +- pkg/server_test.go | 4 +- pkg/store/gcloud/store.go | 2 +- pkg/store/gcloud/store_test.go | 4 +- 18 files changed, 340 insertions(+), 65 deletions(-) create mode 100644 pkg/api/v1/errors_test.go diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 38f5147..07f89cf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,7 +16,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.20.x + go-version: 1.21.x - name: Install Staticcheck run: go install honnef.co/go/tools/cmd/staticcheck@2023.1.3 diff --git a/LICENSE b/LICENSE index b7489e9..eeb13d6 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023 trisacrypto +Copyright (c) 2023 TRISA Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 9545d95..454c1ef 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,17 @@ -# courier -A standalone certificate delivery service +# Courier + +A stand-alone service that allows the GDS to deliver TRISA certificates via a webhook +rather than email. The service accepts PCKS12 passwords and encrypted certificates from +TRISA as HTTP `POST` requests and stores the certificates and passwords in either +Google Secret Manager or on the local disk (other secret management backends such as +Vault or Postgres may be available in the future). + +This tool is mostly used by TRISA Service Providers (TSPs) who have to handle many +TRISA certificate deliveries at a time. VASPs who want to automate certificate delivery +may also use this service. + +## Deploying with Docker + +The simplest way to run the courier service is to use the docker image +`trisa/courier:latest` and to configure it from the environment. This allows the +courier service to be easily run on a Kubernetes cluster. diff --git a/cmd/courier/main.go b/cmd/courier/main.go index b0aeab2..5ff2b69 100644 --- a/cmd/courier/main.go +++ b/cmd/courier/main.go @@ -277,7 +277,7 @@ func storeCertificate(c *cli.Context) (err error) { // Get a secret from the secret manager. func getSecret(c *cli.Context) (err error) { - conf := config.SecretsConfig{ + conf := config.GCPSecretsConfig{ Enabled: true, Project: c.String("project"), Credentials: c.String("credentials"), diff --git a/containers/courier/Dockerfile b/containers/courier/Dockerfile index f33c381..0fa2ed6 100644 --- a/containers/courier/Dockerfile +++ b/containers/courier/Dockerfile @@ -1,5 +1,5 @@ # Dynamic Builds -ARG BUILDER_IMAGE=golang:1.20-buster +ARG BUILDER_IMAGE=golang:1.21-buster ARG FINAL_IMAGE=debian:buster-slim # Build Stage diff --git a/go.mod b/go.mod index bcc17cb..3a2f17d 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,10 @@ module github.com/trisacrypto/courier -go 1.20 +go 1.21 require ( cloud.google.com/go/secretmanager v1.11.2 + github.com/cenkalti/backoff/v4 v4.2.1 github.com/gin-gonic/gin v1.9.1 github.com/googleapis/gax-go v1.0.3 github.com/joho/godotenv v1.5.1 diff --git a/go.sum b/go.sum index dd85ba8..527ab89 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,7 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= cloud.google.com/go v0.110.6 h1:8uYAkj3YHTP/1iwReuHPxLSbdcyc+dSBbzFMrVwDR6Q= +cloud.google.com/go v0.110.6/go.mod h1:+EYjdK8e5RME/VY/qLCAtuyALQ9q67dvuum8i+H5xsI= cloud.google.com/go/compute v1.23.0 h1:tP41Zoavr8ptEqaW6j+LQOnyBBhO7OkOMAGrgLopTwY= cloud.google.com/go/compute v1.23.0/go.mod h1:4tCnrn48xsqlwSAiLf1HXMQk8CONslYbdiEZc9FEIbM= cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= @@ -16,6 +17,8 @@ github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kd github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= +github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= @@ -48,6 +51,7 @@ github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= diff --git a/pkg/api/v1/client.go b/pkg/api/v1/client.go index f3cfff1..7e817ef 100644 --- a/pkg/api/v1/client.go +++ b/pkg/api/v1/client.go @@ -4,14 +4,23 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" "net/http" "net/url" "time" + + "github.com/cenkalti/backoff/v4" ) +const DefaultRetries = 3 + +func DefaultBackoff() BackoffFactory { + return func() backoff.BackOff { + return backoff.NewExponentialBackOff() + } +} + // New creates a new API client that implements the CourierClient interface. func New(endpoint string, opts ...ClientOption) (_ CourierClient, err error) { if endpoint == "" { @@ -19,7 +28,7 @@ func New(endpoint string, opts ...ClientOption) (_ CourierClient, err error) { } // Create a client with the parsed endpoint. - c := &APIv1{} + c := &APIv1{retries: -1} if c.url, err = url.Parse(endpoint); err != nil { return nil, err } @@ -39,13 +48,26 @@ func New(endpoint string, opts ...ClientOption) (_ CourierClient, err error) { Timeout: 30 * time.Second, } } + + // If backoff hasn't been specified add the default backoff factory + if c.backoff == nil { + c.backoff = DefaultBackoff() + } + + // If retries haven't been specified add the default number of retries + if c.retries < 0 { + c.retries = DefaultRetries + } + return c, nil } // APIv1 implements the CourierClient interface. type APIv1 struct { - url *url.URL - client *http.Client + url *url.URL + client *http.Client + backoff BackoffFactory + retries int } var _ CourierClient = &APIv1{} @@ -172,8 +194,47 @@ func (c *APIv1) NewRequest(ctx context.Context, method, path string, data interf } // Do executes an http request against the server, performs error checking, and -// deserializes response data into the specified struct. +// deserializes response data into the specified struct. This function also manages +// retries using a backoff strategy. func (s *APIv1) Do(req *http.Request, data interface{}, checkStatus bool) (rep *http.Response, err error) { + attempts := 0 + start := time.Now() + ctx := req.Context() + delay := s.backoff() + errs := make([]error, 0, s.retries+1) + + for attempts <= s.retries { + attempts++ + if rep, err = s.do(req, data, checkStatus); err == nil { + // Success! + return rep, nil + } + + // Failure! Retry as needed. + errs = append(errs, err) + + // Compute the backoff delay before the next request + dur := delay.NextBackOff() + if dur == backoff.Stop { + // Stop indicates no more retries should be allowed. + return rep, JoinStatusErrors(attempts, time.Since(start), errs...) + } + + // Wait for backoff delay or until context is canceled + wait := time.After(dur) + select { + case <-ctx.Done(): + errs = append(errs, ctx.Err()) + return rep, JoinStatusErrors(attempts, time.Since(start), errs...) + case <-wait: + continue + } + } + + return rep, JoinStatusErrors(attempts, time.Since(start), errs...) +} + +func (s *APIv1) do(req *http.Request, data interface{}, checkStatus bool) (rep *http.Response, err error) { if rep, err = s.client.Do(req); err != nil { return rep, err } @@ -189,8 +250,7 @@ func (s *APIv1) Do(req *http.Request, data interface{}, checkStatus bool) (rep * return rep, NewStatusError(rep.StatusCode, reply.Error) } } - - return rep, errors.New(rep.Status) + return rep, NewStatusError(rep.StatusCode, rep.Status) } } diff --git a/pkg/api/v1/errors.go b/pkg/api/v1/errors.go index 8936752..20b7d28 100644 --- a/pkg/api/v1/errors.go +++ b/pkg/api/v1/errors.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "time" "github.com/gin-gonic/gin" ) @@ -15,21 +16,9 @@ var ( notAllowed = Reply{Success: false, Error: "method not allowed"} ErrEndpointRequired = errors.New("endpoint is required") ErrIDRequired = errors.New("missing ID in request") + ErrInvalidRetries = errors.New("number of retries must be zero or more") ) -func NewStatusError(code int, err string) error { - return &StatusError{Code: code, Err: err} -} - -type StatusError struct { - Code int - Err string -} - -func (e StatusError) Error() string { - return fmt.Sprintf("[%d]: %s", e.Code, e.Err) -} - // ErrorResponse constructs an new response from the error or returns a success: false. func ErrorResponse(err interface{}) Reply { if err == nil { @@ -57,6 +46,73 @@ func ErrorResponse(err interface{}) Reply { return rep } +func NewStatusError(code int, err string) error { + return &StatusError{Code: code, Err: err} +} + +type StatusError struct { + Code int + Err string +} + +func (e StatusError) Error() string { + return fmt.Sprintf("[%d]: %s", e.Code, e.Err) +} + +// Deduplicates status errors and creates a multi-status error to return. Removes nil +// errors and returns nil if all errs are nil. If only one errors is returned, return +// that error instead of a multierror (e.g. if all responses have the same status code). +func JoinStatusErrors(attempts int, delay time.Duration, errs ...error) error { + err := &MultiStatusError{ + Errs: make([]error, 0), + Attempts: attempts, + } + + seen := make(map[string]struct{}) + for _, e := range errs { + if e == nil { + continue + } + + if _, ok := seen[e.Error()]; ok { + continue + } + + err.Errs = append(err.Errs, e) + seen[e.Error()] = struct{}{} + } + + switch len(err.Errs) { + case 0: + return nil + case 1: + return err.Errs[0] + default: + return err + } +} + +type MultiStatusError struct { + Errs []error + Attempts int + Delay time.Duration +} + +func (e *MultiStatusError) Error() string { + return fmt.Sprintf("after %d attempts: %s", e.Attempts, e.Last()) +} + +func (e *MultiStatusError) Last() error { + if len(e.Errs) > 0 { + return e.Errs[len(e.Errs)-1] + } + return nil +} + +func (e *MultiStatusError) Unwrap() []error { + return e.Errs +} + // NotFound returns a standard 404 response. func NotFound(c *gin.Context) { c.JSON(http.StatusNotFound, notFound) diff --git a/pkg/api/v1/errors_test.go b/pkg/api/v1/errors_test.go new file mode 100644 index 0000000..8b97634 --- /dev/null +++ b/pkg/api/v1/errors_test.go @@ -0,0 +1,74 @@ +package api_test + +import ( + "errors" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/trisacrypto/courier/pkg/api/v1" +) + +func TestJoinStatusErrors(t *testing.T) { + t.Run("Empty", func(t *testing.T) { + err := api.JoinStatusErrors(0, 0, nil) + require.NoError(t, err, "expected a nil error returned") + + err = api.JoinStatusErrors(0, 0, nil, nil, nil, nil, nil, nil) + require.NoError(t, err, "expected a nil error returned for multiple nil errors") + }) + + t.Run("SingleStatusError", func(t *testing.T) { + err := api.JoinStatusErrors(1, 421*time.Millisecond, api.NewStatusError(http.StatusServiceUnavailable, "could not reach specified service")) + require.Error(t, err, "expected error to be returned") + + serr, ok := err.(*api.StatusError) + require.True(t, ok, "expected error to be a status error, not a multi status error") + require.Equal(t, 503, serr.Code) + }) + + t.Run("SingleError", func(t *testing.T) { + err := api.JoinStatusErrors(1, 421*time.Millisecond, errors.New("something went wrong")) + require.Error(t, err, "expected error to be returned") + + _, ok := err.(*api.StatusError) + require.False(t, ok, "expected error to not be a status error") + require.EqualError(t, err, "something went wrong") + }) + + t.Run("MultiStatusErrors", func(t *testing.T) {}) + + t.Run("MultiErrors", func(t *testing.T) {}) + + t.Run("Mixed", func(t *testing.T) {}) + + t.Run("Deduplication", func(t *testing.T) {}) + + t.Run("MultiDeduplication", func(t *testing.T) {}) +} + +func TestMultiStatusError(t *testing.T) { + testCases := []struct { + err *api.MultiStatusError + expected string + }{ + { + &api.MultiStatusError{ + Attempts: 1, + Delay: 585 * time.Millisecond, + Errs: []error{ + &api.StatusError{ + Code: http.StatusInternalServerError, + Err: http.StatusText(http.StatusInternalServerError), + }, + }, + }, + "after 1 attempts: [500]: Internal Server Error", + }, + } + + for i, tc := range testCases { + require.EqualError(t, tc.err, tc.expected, "test case %d failed", i) + } +} diff --git a/pkg/api/v1/options.go b/pkg/api/v1/options.go index 05fe69d..7baa5b2 100644 --- a/pkg/api/v1/options.go +++ b/pkg/api/v1/options.go @@ -1,4 +1,69 @@ package api +import ( + "crypto/tls" + "net/http" + "time" + + "github.com/cenkalti/backoff/v4" +) + // ClientOption allows the API client to be configured when it is created. type ClientOption func(c *APIv1) error + +// BackoffFactory creates a new backoff delay for a specific request. +type BackoffFactory func() backoff.BackOff + +// WithBackoff allows the user to create a client that retries requests with a fixed or +// exponential backoff to allow the remote service time to recover. By default, the +// courier client uses exponential backoff and three retries. +func WithBackoff(bf BackoffFactory) ClientOption { + return func(c *APIv1) error { + c.backoff = bf + return nil + } +} + +// WithZeroBackoff creates a client that retries immediately without delay. +func WithZeroBackoff() ClientOption { + return func(c *APIv1) error { + c.backoff = func() backoff.BackOff { + return &backoff.ZeroBackOff{} + } + return nil + } +} + +// WithRetries allows the user to create a client that retries requests for the +// specified number of attempts. Set to zero to only send one request with no retries. +// The default number of retry attempts is 3. +func WithRetries(attempts int) ClientOption { + return func(c *APIv1) error { + if attempts < 0 { + return ErrInvalidRetries + } + + c.retries = attempts + return nil + } +} + +// WithTLSConfig allows the user to specify a custom tls configuration for the client. +func WithTLSConfig(conf *tls.Config) ClientOption { + return func(c *APIv1) error { + if c.client != nil { + c.client.Transport = &http.Transport{ + TLSClientConfig: conf, + } + } else { + c.client = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: conf, + }, + CheckRedirect: nil, + Timeout: 30 * time.Second, + } + } + return nil + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 79b6179..cbffd57 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -9,12 +9,12 @@ import ( ) type Config struct { - BindAddr string `split_words:"true" default:":8842"` - Mode string `split_words:"true" default:"release"` - MTLS MTLSConfig `split_words:"true"` - LocalStorage LocalStorageConfig `split_words:"true"` - SecretManager SecretsConfig `split_words:"true"` - processed bool + BindAddr string `split_words:"true" default:":8842"` + Mode string `split_words:"true" default:"release"` + MTLS MTLSConfig `split_words:"true"` + LocalStorage LocalStorageConfig `split_words:"true"` + GCPSecretManager GCPSecretsConfig `split_words:"true"` + processed bool } type MTLSConfig struct { @@ -30,7 +30,7 @@ type LocalStorageConfig struct { Path string `split_words:"true"` } -type SecretsConfig struct { +type GCPSecretsConfig struct { Enabled bool `split_words:"true" default:"false"` Credentials string `split_words:"true"` Project string `split_words:"true"` @@ -76,11 +76,11 @@ func (c Config) Validate() (err error) { return err } - if !c.LocalStorage.Enabled && !c.SecretManager.Enabled { + if !c.LocalStorage.Enabled && !c.GCPSecretManager.Enabled { return ErrNoStorageEnabled } - if c.LocalStorage.Enabled && c.SecretManager.Enabled { + if c.LocalStorage.Enabled && c.GCPSecretManager.Enabled { return ErrMultipleStorageEnabled } @@ -88,7 +88,7 @@ func (c Config) Validate() (err error) { return err } - if err = c.SecretManager.Validate(); err != nil { + if err = c.GCPSecretManager.Validate(); err != nil { return err } @@ -199,7 +199,7 @@ func (c LocalStorageConfig) Validate() (err error) { return nil } -func (c SecretsConfig) Validate() (err error) { +func (c GCPSecretsConfig) Validate() (err error) { if !c.Enabled { return nil } diff --git a/pkg/config/config_test.go b/pkg/config/config_test.go index 0f44bb8..7914fbf 100644 --- a/pkg/config/config_test.go +++ b/pkg/config/config_test.go @@ -10,16 +10,16 @@ import ( // Define a test environment for the config tests. var testEnv = map[string]string{ - "COURIER_BIND_ADDR": ":8080", - "COURIER_MODE": "debug", - "COURIER_MTLS_INSECURE": "false", - "COURIER_MTLS_CERT_PATH": "/path/to/cert", - "COURIER_MTLS_POOL_PATH": "/path/to/pool", - "COURIER_LOCAL_STORAGE_ENABLED": "true", - "COURIER_LOCAL_STORAGE_PATH": "/path/to/storage", - "COURIER_SECRET_MANAGER_ENABLED": "true", - "COURIER_SECRET_MANAGER_CREDENTIALS": "test-credentials", - "COURIER_SECRET_MANAGER_PROJECT": "test-project", + "COURIER_BIND_ADDR": ":8080", + "COURIER_MODE": "debug", + "COURIER_MTLS_INSECURE": "false", + "COURIER_MTLS_CERT_PATH": "/path/to/cert", + "COURIER_MTLS_POOL_PATH": "/path/to/pool", + "COURIER_LOCAL_STORAGE_ENABLED": "true", + "COURIER_LOCAL_STORAGE_PATH": "/path/to/storage", + "COURIER_GCP_SECRET_MANAGER_ENABLED": "true", + "COURIER_GCP_SECRET_MANAGER_CREDENTIALS": "test-credentials", + "COURIER_GCP_SECRET_MANAGER_PROJECT": "test-project", } func TestConfig(t *testing.T) { @@ -47,9 +47,9 @@ func TestConfig(t *testing.T) { require.Equal(t, testEnv["COURIER_MTLS_POOL_PATH"], conf.MTLS.PoolPath) require.True(t, conf.LocalStorage.Enabled) require.Equal(t, testEnv["COURIER_LOCAL_STORAGE_PATH"], conf.LocalStorage.Path) - require.True(t, conf.SecretManager.Enabled) - require.Equal(t, testEnv["COURIER_SECRET_MANAGER_CREDENTIALS"], conf.SecretManager.Credentials) - require.Equal(t, testEnv["COURIER_SECRET_MANAGER_PROJECT"], conf.SecretManager.Project) + require.True(t, conf.GCPSecretManager.Enabled) + require.Equal(t, testEnv["COURIER_GCP_SECRET_MANAGER_CREDENTIALS"], conf.GCPSecretManager.Credentials) + require.Equal(t, testEnv["COURIER_GCP_SECRET_MANAGER_PROJECT"], conf.GCPSecretManager.Project) } func TestValidate(t *testing.T) { @@ -91,7 +91,7 @@ func TestValidate(t *testing.T) { MTLS: config.MTLSConfig{ Insecure: true, }, - SecretManager: config.SecretsConfig{ + GCPSecretManager: config.GCPSecretsConfig{ Enabled: true, Credentials: "test-credentials", Project: "test-project", @@ -153,7 +153,7 @@ func TestValidate(t *testing.T) { Enabled: true, Path: "/path/to/storage", }, - SecretManager: config.SecretsConfig{ + GCPSecretManager: config.GCPSecretsConfig{ Enabled: true, Credentials: "test-credentials", Project: "test-project", @@ -179,7 +179,7 @@ func TestValidate(t *testing.T) { func TestValidateSecretConfig(t *testing.T) { t.Run("ValidSecretConfig", func(t *testing.T) { - conf := config.SecretsConfig{ + conf := config.GCPSecretsConfig{ Enabled: true, Credentials: "test-credentials", Project: "test-project", @@ -188,12 +188,12 @@ func TestValidateSecretConfig(t *testing.T) { }) t.Run("ValidDisabled", func(t *testing.T) { - conf := config.SecretsConfig{} + conf := config.GCPSecretsConfig{} require.NoError(t, conf.Validate(), "expected disabled secret config to be valid") }) t.Run("MissingCredentials", func(t *testing.T) { - conf := config.SecretsConfig{ + conf := config.GCPSecretsConfig{ Enabled: true, Project: "test-project", } @@ -201,7 +201,7 @@ func TestValidateSecretConfig(t *testing.T) { }) t.Run("MissingProject", func(t *testing.T) { - conf := config.SecretsConfig{ + conf := config.GCPSecretsConfig{ Enabled: true, Credentials: "test-credentials", } diff --git a/pkg/secrets/client.go b/pkg/secrets/client.go index fa8fb2a..37cabf7 100644 --- a/pkg/secrets/client.go +++ b/pkg/secrets/client.go @@ -15,7 +15,7 @@ import ( ) // NewClient creates a secret manager client from the configuration. -func NewClient(conf config.SecretsConfig, opts ...SecretsOption) (_ SecretManagerClient, err error) { +func NewClient(conf config.GCPSecretsConfig, opts ...SecretsOption) (_ SecretManagerClient, err error) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/pkg/server.go b/pkg/server.go index 959ba2f..5c17e11 100644 --- a/pkg/server.go +++ b/pkg/server.go @@ -40,8 +40,8 @@ func New(conf config.Config) (s *Server, err error) { if s.store, err = local.Open(conf.LocalStorage); err != nil { return nil, err } - case conf.SecretManager.Enabled: - if s.store, err = gcloud.Open(conf.SecretManager); err != nil { + case conf.GCPSecretManager.Enabled: + if s.store, err = gcloud.Open(conf.GCPSecretManager); err != nil { return nil, err } default: diff --git a/pkg/server_test.go b/pkg/server_test.go index a762458..e5a7640 100644 --- a/pkg/server_test.go +++ b/pkg/server_test.go @@ -52,9 +52,9 @@ func (s *courierTestSuite) SetupSuite() { // Wait for the server to start serving the API time.Sleep(500 * time.Millisecond) - // Create an API client to use in tests + // Create an API client to use in tests (no retries, no backoff for testing errors) url := s.courier.URL() - s.client, err = api.New(url) + s.client, err = api.New(url, api.WithRetries(0), api.WithZeroBackoff()) require.NoError(err, "could not create test client") } diff --git a/pkg/store/gcloud/store.go b/pkg/store/gcloud/store.go index bd104fa..54c2b79 100644 --- a/pkg/store/gcloud/store.go +++ b/pkg/store/gcloud/store.go @@ -10,7 +10,7 @@ import ( ) // Open the google cloud storage backend. -func Open(conf config.SecretsConfig, opts ...StoreOption) (store *Store, err error) { +func Open(conf config.GCPSecretsConfig, opts ...StoreOption) (store *Store, err error) { store = &Store{} // Apply provided options diff --git a/pkg/store/gcloud/store_test.go b/pkg/store/gcloud/store_test.go index a487d14..554ed9b 100644 --- a/pkg/store/gcloud/store_test.go +++ b/pkg/store/gcloud/store_test.go @@ -19,7 +19,7 @@ import ( type gcloudStoreTestSuite struct { suite.Suite store *gcloud.Store - conf config.SecretsConfig + conf config.GCPSecretsConfig sm *mock.SecretManager } @@ -27,7 +27,7 @@ func (s *gcloudStoreTestSuite) SetupSuite() { // Open the storage backend using a mock secrets client var err error s.sm = mock.New() - s.conf = config.SecretsConfig{ + s.conf = config.GCPSecretsConfig{ Enabled: true, Credentials: "creds.json", Project: "project", From 3dc820a994162f9d55a1dd5ce57308ce7a327f88 Mon Sep 17 00:00:00 2001 From: Benjamin Bengfort Date: Wed, 3 Jan 2024 17:38:44 -0600 Subject: [PATCH 2/2] update tests --- pkg/api/v1/client_test.go | 31 +++++++++++++++++ pkg/api/v1/errors.go | 3 ++ pkg/api/v1/errors_test.go | 70 ++++++++++++++++++++++++++++++++++----- 3 files changed, 96 insertions(+), 8 deletions(-) diff --git a/pkg/api/v1/client_test.go b/pkg/api/v1/client_test.go index 73db215..ac91ad8 100644 --- a/pkg/api/v1/client_test.go +++ b/pkg/api/v1/client_test.go @@ -4,8 +4,11 @@ import ( "context" "net/http" "net/http/httptest" + "sync/atomic" "testing" + "time" + "github.com/cenkalti/backoff/v4" "github.com/stretchr/testify/require" "github.com/trisacrypto/courier/pkg/api/v1" ) @@ -63,3 +66,31 @@ func TestStoreCertificatePassword(t *testing.T) { err = client.StoreCertificatePassword(context.Background(), req) require.ErrorIs(t, err, api.ErrIDRequired, "client should error if no ID is provided") } + +func TestRetriesWithBackoff(t *testing.T) { + // Create a test server + var attempts uint32 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddUint32(&attempts, 1) + http.Error(w, http.StatusText(http.StatusTooEarly), http.StatusTooEarly) + })) + defer ts.Close() + + // Create a client to test the client method + client, err := api.New(ts.URL, api.WithRetries(10), api.WithBackoff(func() backoff.BackOff { + return backoff.NewConstantBackOff(100 * time.Millisecond) + })) + require.NoError(t, err, "could not create client") + + rawClient, ok := client.(*api.APIv1) + require.True(t, ok, "expected client to be an APIv1 client") + + req, err := rawClient.NewRequest(context.Background(), http.MethodGet, "/", nil, nil) + require.NoError(t, err, "could not create request") + + start := time.Now() + _, err = rawClient.Do(req, nil, true) + require.Error(t, err, "expected an error to be returned") + require.Equal(t, uint32(11), attempts, "expected 10 retry attempts") + require.Greater(t, time.Since(start), 950*time.Millisecond, "expected backoff delay") +} diff --git a/pkg/api/v1/errors.go b/pkg/api/v1/errors.go index 20b7d28..50b374e 100644 --- a/pkg/api/v1/errors.go +++ b/pkg/api/v1/errors.go @@ -47,6 +47,9 @@ func ErrorResponse(err interface{}) Reply { } func NewStatusError(code int, err string) error { + if err == "" { + err = http.StatusText(code) + } return &StatusError{Code: code, Err: err} } diff --git a/pkg/api/v1/errors_test.go b/pkg/api/v1/errors_test.go index 8b97634..cf5e761 100644 --- a/pkg/api/v1/errors_test.go +++ b/pkg/api/v1/errors_test.go @@ -20,12 +20,12 @@ func TestJoinStatusErrors(t *testing.T) { }) t.Run("SingleStatusError", func(t *testing.T) { - err := api.JoinStatusErrors(1, 421*time.Millisecond, api.NewStatusError(http.StatusServiceUnavailable, "could not reach specified service")) + err := api.JoinStatusErrors(1, 421*time.Millisecond, api.NewStatusError(http.StatusServiceUnavailable, "")) require.Error(t, err, "expected error to be returned") - serr, ok := err.(*api.StatusError) + _, ok := err.(*api.StatusError) require.True(t, ok, "expected error to be a status error, not a multi status error") - require.Equal(t, 503, serr.Code) + require.EqualError(t, err, "[503]: Service Unavailable") }) t.Run("SingleError", func(t *testing.T) { @@ -37,15 +37,69 @@ func TestJoinStatusErrors(t *testing.T) { require.EqualError(t, err, "something went wrong") }) - t.Run("MultiStatusErrors", func(t *testing.T) {}) + t.Run("MultiStatusErrors", func(t *testing.T) { + err := api.JoinStatusErrors(3, 1829*time.Millisecond, + api.NewStatusError(http.StatusUnauthorized, ""), + api.NewStatusError(http.StatusServiceUnavailable, ""), + api.NewStatusError(http.StatusInsufficientStorage, ""), + ) + require.Error(t, err, "expected error to be returned") - t.Run("MultiErrors", func(t *testing.T) {}) + _, ok := err.(*api.MultiStatusError) + require.True(t, ok, "expected error to be a multi-status error") + require.EqualError(t, err, "after 3 attempts: [507]: Insufficient Storage") + }) - t.Run("Mixed", func(t *testing.T) {}) + t.Run("MultiErrors", func(t *testing.T) { + err := api.JoinStatusErrors(2, 727*time.Millisecond, + errors.New("oopsie"), errors.New("something went wrong"), + ) + require.Error(t, err, "expected error to be returned") - t.Run("Deduplication", func(t *testing.T) {}) + _, ok := err.(*api.MultiStatusError) + require.True(t, ok, "expected error to be a multi-status error") + require.EqualError(t, err, "after 2 attempts: something went wrong") + }) + + t.Run("Mixed", func(t *testing.T) { + err := api.JoinStatusErrors(2, 3217*time.Millisecond, + api.NewStatusError(http.StatusServiceUnavailable, ""), + errors.New("something went wrong"), + ) + require.Error(t, err, "expected error to be returned") + + _, ok := err.(*api.MultiStatusError) + require.True(t, ok, "expected error to be a multi-status error") + require.EqualError(t, err, "after 2 attempts: something went wrong") + }) - t.Run("MultiDeduplication", func(t *testing.T) {}) + t.Run("Deduplication", func(t *testing.T) { + err := api.JoinStatusErrors(3, 2451*time.Millisecond, + api.NewStatusError(http.StatusServiceUnavailable, ""), + api.NewStatusError(http.StatusServiceUnavailable, ""), + api.NewStatusError(http.StatusServiceUnavailable, ""), + ) + require.Error(t, err, "expected error to be returned") + + _, ok := err.(*api.StatusError) + require.True(t, ok, "expected error to be a status error") + require.EqualError(t, err, "[503]: Service Unavailable") + }) + + t.Run("MultiDeduplication", func(t *testing.T) { + err := api.JoinStatusErrors(5, 3257*time.Millisecond, + api.NewStatusError(http.StatusUnauthorized, ""), + api.NewStatusError(http.StatusServiceUnavailable, ""), + api.NewStatusError(http.StatusUnauthorized, ""), + api.NewStatusError(http.StatusInsufficientStorage, ""), + api.NewStatusError(http.StatusServiceUnavailable, ""), + ) + require.Error(t, err, "expected error to be returned") + + _, ok := err.(*api.MultiStatusError) + require.True(t, ok, "expected error to be a multi-status error") + require.EqualError(t, err, "after 5 attempts: [507]: Insufficient Storage") + }) } func TestMultiStatusError(t *testing.T) {