diff --git a/internal/client/client.go b/internal/client/client.go index 9cd6de59..cae7cac6 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -2,11 +2,13 @@ package client import ( "bytes" + "context" "crypto/hmac" "crypto/sha1" // nolint:gosec "encoding/hex" "encoding/json" "fmt" + "io" "io/ioutil" "net/http" "net/http/httputil" @@ -27,6 +29,17 @@ const ( SharedSecret = "complement" ) +type CtxKey string + +const ( + CtxKeyWithRetryUntil CtxKey = "complement_retry_until" // contains *retryUntilParams +) + +type retryUntilParams struct { + timeout time.Duration + untilFn func(*http.Response) bool +} + // RequestOpt is a functional option which will modify an outgoing HTTP request. // See functions starting with `With...` in this package for more info. type RequestOpt func(req *http.Request) @@ -441,6 +454,16 @@ func WithQueries(q url.Values) RequestOpt { } } +// WithRetryUntil will retry the request until the provided function returns true. Times out after +// `timeout`, which will then fail the test. +func WithRetryUntil(timeout time.Duration, untilFn func(res *http.Response) bool) RequestOpt { + return func(req *http.Request) { + until := req.Context().Value(CtxKeyWithRetryUntil).(*retryUntilParams) + until.timeout = timeout + until.untilFn = untilFn + } +} + // MustDoFunc is the same as DoFunc but fails the test if the returned HTTP response code is not 2xx. func (c *CSAPI) MustDoFunc(t *testing.T, method string, paths []string, opts ...RequestOpt) *http.Response { t.Helper() @@ -479,6 +502,9 @@ func (c *CSAPI) DoFunc(t *testing.T, method string, paths []string, opts ...Requ if c.AccessToken != "" { req.Header.Set("Authorization", "Bearer "+c.AccessToken) } + retryUntil := &retryUntilParams{} + ctx := context.WithValue(req.Context(), CtxKeyWithRetryUntil, retryUntil) + req = req.WithContext(ctx) // set functional options for _, o := range opts { @@ -502,21 +528,48 @@ func (c *CSAPI) DoFunc(t *testing.T, method string, paths []string, opts ...Requ t.Logf("Request body: ", contentType) } } - // Perform the HTTP request - res, err := c.Client.Do(req) - if err != nil { - t.Fatalf("CSAPI.DoFunc response returned error: %s", err) - } - // debug log the response - if c.Debug && res != nil { - var dump []byte - dump, err = httputil.DumpResponse(res, true) + now := time.Now() + for { + // Perform the HTTP request + res, err := c.Client.Do(req) if err != nil { - t.Fatalf("CSAPI.DoFunc failed to dump response body: %s", err) + t.Fatalf("CSAPI.DoFunc response returned error: %s", err) + } + // debug log the response + if c.Debug && res != nil { + var dump []byte + dump, err = httputil.DumpResponse(res, true) + if err != nil { + t.Fatalf("CSAPI.DoFunc failed to dump response body: %s", err) + } + t.Logf("%s", string(dump)) + } + if retryUntil == nil || retryUntil.timeout == 0 { + return res // don't retry + } + + // check the condition, make a copy of the response body first in case the check consumes it + var resBody []byte + if res.Body != nil { + resBody, err = ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("CSAPI.DoFunc failed to read response body for RetryUntil check: %s", err) + } + res.Body = io.NopCloser(bytes.NewBuffer(resBody)) } - t.Logf("%s", string(dump)) + if retryUntil.untilFn(res) { + // remake the response and return + res.Body = io.NopCloser(bytes.NewBuffer(resBody)) + return res + } + // condition not satisfied, do we timeout yet? + if time.Since(now) > retryUntil.timeout { + t.Fatalf("CSAPI.DoFunc RetryUntil: %v %v timed out after %v", method, req.URL, retryUntil.timeout) + } + t.Logf("CSAPI.DoFunc RetryUntil: %v %v response condition not yet met, retrying", method, req.URL) + // small sleep to avoid tight-looping + time.Sleep(100 * time.Millisecond) } - return res } // NewLoggedClient returns an http.Client which logs requests/responses