Skip to content

Commit

Permalink
add option to pass retryable response handler
Browse files Browse the repository at this point in the history
  • Loading branch information
gavriel-hc committed Apr 11, 2022
1 parent af0a5a3 commit 5e7d7d1
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 3 deletions.
25 changes: 22 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,11 @@ func PassthroughErrorHandler(resp *http.Response, err error, _ int) (*http.Respo

// Do wraps calling an HTTP method with retries.
func (c *Client) Do(req *Request) (*http.Response, error) {
return c.DoWithResponseHandler(req, nil)
}

// DoWithResponseHandler wraps calling an HTTP method plus a response handler with retries.
func (c *Client) DoWithResponseHandler(req *Request, handler func(*http.Response) (shouldRetry bool)) (*http.Response, error) {
c.clientInit.Do(func() {
if c.HTTPClient == nil {
c.HTTPClient = cleanhttp.DefaultPooledClient()
Expand Down Expand Up @@ -606,9 +611,6 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
// Attempt the request
resp, doErr = c.HTTPClient.Do(req.Request)

// Check if we should continue with retries.
shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr)

if doErr != nil {
switch v := logger.(type) {
case LeveledLogger:
Expand All @@ -632,6 +634,13 @@ func (c *Client) Do(req *Request) (*http.Response, error) {
}
}

// Check if we should continue with retries.
shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr)

successSoFar := !shouldRetry && doErr == nil && checkErr == nil
if successSoFar && handler != nil {
shouldRetry = handler(resp)
}
if !shouldRetry {
break
}
Expand Down Expand Up @@ -739,6 +748,16 @@ func (c *Client) Get(url string) (*http.Response, error) {
return c.Do(req)
}

// GetWithResponseHandler is a helper for doing a GET request followed by a function on the response.
// The intention is for this to be used when errors in the response handling should also be retried.
func (c *Client) GetWithResponseHandler(url string, handler func(*http.Response) (shouldRetry bool)) (*http.Response, error) {
req, err := NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
return c.DoWithResponseHandler(req, handler)
}

// Head is a shortcut for doing a HEAD request without making a new client.
func Head(url string) (*http.Response, error) {
return defaultClient.Head(url)
Expand Down
77 changes: 77 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,83 @@ func testClientDo(t *testing.T, body interface{}) {
}
}

func TestClient_DoWithHandler(t *testing.T) {
// Create the client. Use short retry windows so we fail faster.
client := NewClient()
client.RetryWaitMin = 10 * time.Millisecond
client.RetryWaitMax = 10 * time.Millisecond
client.RetryMax = 2

var attempts int
client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) {
attempts++
return DefaultRetryPolicy(context.TODO(), resp, err)
}

// Mock server which always responds 200.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
}))
defer ts.Close()

alternatingBool := false
tests := []struct {
name string
handler func(*http.Response) bool
expectedAttempts int
err string
}{
{
name: "nil handler",
handler: nil,
expectedAttempts: 1,
},
{
name: "handler never should retry",
handler: func(*http.Response) bool { return false },
expectedAttempts: 1,
},
{
name: "handler alternates should retry",
handler: func(*http.Response) bool {
alternatingBool = !alternatingBool
return alternatingBool
},
expectedAttempts: 2,
},
{
name: "handler always should retry",
handler: func(*http.Response) bool { return true },
expectedAttempts: 3,
err: "giving up after 3 attempt(s)",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
attempts = 0
// Create the request
req, err := NewRequest("GET", ts.URL, nil)
if err != nil {
t.Fatalf("err: %v", err)
}

// Send the request.
_, err = client.DoWithResponseHandler(req, tt.handler)
if err != nil && !strings.Contains(err.Error(), tt.err) {
t.Fatalf("error does not match expectation, expected: %s, got: %s", tt.err, err.Error())
}
if err == nil && tt.err != "" {
t.Fatalf("no error, expected: %s", tt.err)
}

if attempts != tt.expectedAttempts {
t.Fatalf("expected %d attempts, got %d attempts", tt.expectedAttempts, attempts)
}
})
}
}

func TestClient_Do_fails(t *testing.T) {
// Mock server which always responds 500.
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down

0 comments on commit 5e7d7d1

Please sign in to comment.