Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix POST retries #45

Merged
merged 5 commits into from
Feb 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 68 additions & 20 deletions pester.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,16 @@ func (c *Client) copyBody(src io.ReadCloser) ([]byte, error) {
return b, nil
}

// resetBody resets the Body and GetBody fields of an http.Request to new Readers over
// the originalBody. This is used to refresh http.Requests that may have had their
// bodies closed already.
func resetBody(request *http.Request, originalBody []byte) {
request.Body = io.NopCloser(bytes.NewBuffer(originalBody))
request.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(bytes.NewBuffer(originalBody)), nil
}
}

// pester provides all the logic of retries, concurrency, backoff, and logging
func (c *Client) pester(p params) (*http.Response, error) {
resultCh := make(chan result)
Expand Down Expand Up @@ -242,7 +252,6 @@ func (c *Client) pester(p params) (*http.Response, error) {

// if we have a request body, we need to save it for later
var (
request *http.Request
originalBody []byte
err error
)
Expand All @@ -252,23 +261,52 @@ func (c *Client) pester(p params) (*http.Response, error) {
} else if p.body != nil {
originalBody, err = c.copyBody(p.body)
}
if err != nil {
return nil, err
}

// check to make sure that we aren't trying to use an unsupported method
switch p.method {
case methodDo:
request = p.req
case methodGet, methodHead:
request, err = http.NewRequest(p.verb, p.url, nil)
case methodPostForm, methodPost:
request, err = http.NewRequest(http.MethodPost, p.url, ioutil.NopCloser(bytes.NewBuffer(originalBody)))
case methodDo, methodGet, methodHead, methodPostForm, methodPost:
default:
err = ErrUnexpectedMethod
}
if err != nil {
return nil, err
return nil, ErrUnexpectedMethod
}

if len(p.bodyType) > 0 {
request.Header.Set(headerKeyContentType, p.bodyType)
// provideRequest returns an HTTP request to be use when retrying.
// if concurrency is 1, it will return the same request that was supplied to the Do() method
// for Do() calls, otherwise it will generate a Clone() of the request each time it is called.
// For non-Do() calls, it creates a new request each time it is called. This re-creation behaviour
// is because requests are not supposed to be used again until the RoundTripper is finished
// with them, which cannot be guaranteed with concurrent callers
// https://pkg.go.dev/net/http#RoundTripper
provideRequest := func() (request *http.Request, err error) {
switch p.method {
case methodDo:
if concurrency > 1 {
request = p.req.Clone(p.req.Context())
} else {
request = p.req
}
if request.Body != nil {
// reset the body since Clone() doesn't do that for us
// and we drained it earlier when performing the Copy
// ex: https://go.dev/play/p/jlc6A-fjaOi
resetBody(request, originalBody)
}
case methodGet, methodHead:
request, err = http.NewRequest(p.verb, p.url, nil)
case methodPostForm, methodPost:
request, err = http.NewRequest(http.MethodPost, p.url, bytes.NewBuffer(originalBody))
}
if err != nil {
return
}

if len(p.bodyType) > 0 {
request.Header.Set(headerKeyContentType, p.bodyType)
}

return
}

AttemptLimit := c.MaxRetries
Expand All @@ -279,9 +317,15 @@ func (c *Client) pester(p params) (*http.Response, error) {
for n := 0; n < concurrency; n++ {
c.wg.Add(1)
totalSentRequests.Add(1)
go func(n int, req *http.Request) {
go func(n int) {
defer c.wg.Done()
defer totalSentRequests.Done()
req, err := provideRequest()
// couldn't get a request to use, so don't proceed
if err != nil {
multiplexCh <- result{err: err, req: n}
return
}

for i := 1; i <= AttemptLimit; i++ {
c.wg.Add(1)
Expand Down Expand Up @@ -340,15 +384,19 @@ func (c *Client) pester(p params) (*http.Response, error) {
case <-time.After(c.Backoff(i) + 1*time.Microsecond):
// allow context cancellation to cancel during backoff
case <-req.Context().Done():
multiplexCh <- result{resp: resp, err: req.Context().Err()}
return
}
}
}(n, request)

// rehydrate the body (it is drained each read)
if request.Body != nil {
request.Body = ioutil.NopCloser(bytes.NewBuffer(originalBody))
}
// we are about to retry, if we had a Body, we will need to restore it
// to a non-closed one in order to work reliably. If you do not do this,
// there are a number of curious edge cases depending on the type of the
// underlying reader: https://go.dev/play/p/gZLVUe2EXSE
if req.Body != nil {
resetBody(req, originalBody)
}
}
}(n)
}

// spin off the go routine so it can continually listen in on late results and close the response bodies
Expand Down
224 changes: 224 additions & 0 deletions pester_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
Expand Down Expand Up @@ -743,6 +744,171 @@ func TestRetriesNotAttemptedIfContextIsCancelled(t *testing.T) {
}
}

type roundTripperFunc func(r *http.Request) (*http.Response, error)

func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}

func TestRetriesContextCancelledDuringWait(t *testing.T) {
t.Parallel()
// in order for this test to work we need to be able to reliably put the client in a
// waiting state. To achieve this, we create a client that will fail fast
// via a custom RoundTripper that always fails and pair it with a custom BackoffStrategy
// that waits for a long time. This results in a client that should spend
// almost all of its time waiting.

ctx, cancel := context.WithCancel(context.Background())

c := NewExtendedClient(&http.Client{
Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("always fail")
}),
Timeout: 5 * time.Second,
})
c.MaxRetries = 2
c.Backoff = func(retry int) time.Duration {
return 5 * time.Second
}
// req details don't really matter, round-tripper will fail it anyway
req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost", nil)
if err != nil {
t.Fatalf("unable to create request %v", err)
}

// we want to perform the call in a goroutine so we can explicitly check for indefinite
// blocking behaviour. Since you cannot use t.Fatal/t.Error/etc. in a goroutine, we
// create a channel to communicate back to our main goroutine what happened
errReturn := make(chan error)
go func() {
// perform call in goroutine to check for indefinite blocks
_, err := c.Do(req)
errReturn <- err
}()

// wait a hundred ms to let the client fail and get into a waiting state
<-time.After(100 * time.Millisecond)
// cancel our context
cancel()

// if all has gone well, we should have aborted our wait period and the
// err channel should contain a Context-cancellation error

select {
case recdErr := <-errReturn:
if recdErr == nil {
t.Fatal("nil error returned from Do(req) routine")
}
// check that it is the right error message
if context.Canceled != recdErr {
t.Fatalf("unexpected error returned: %v", recdErr)
}
case <-time.After(time.Second):
// give it a second, then treat this as failing to return
t.Fatal("failed to receive error return")
}
}

func TestRetriesWithBodies_Do(t *testing.T) {
t.Parallel()

const testContent = "TestRetriesWithBodies_Do"
// using a channel to route these errors back into this goroutine
// it is important that this channel have enough capacity to hold all
// of the errors that will be generated by the test so that we do not
// deadlock. Therefore, MaxAttempts must be the same size as the channel capacity
// and each execution must only put at most one error on the channel.
serverReqErrCh := make(chan error, 4)
port, closeFn, err := middlewareServer(
contentVerificationMiddleware(serverReqErrCh, testContent),
always500RequestMiddleware(),
)
if err != nil {
t.Fatal("unable to start timeout server", err)
}
defer closeFn()

<-time.After(2 * time.Second)

iseUrl := fmt.Sprintf("http://localhost:%d", port)

req, err := http.NewRequest("POST", iseUrl, strings.NewReader(testContent))
if err != nil {
t.Fatalf("unable to create request %v", err)
}

c := New()
c.MaxRetries = cap(serverReqErrCh)
c.KeepLog = true
c.Backoff = func(retry int) time.Duration {
// backoff isn't important for this test
return 0
}

resp, err := c.Do(req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if resp == nil {
t.Error("response was unexpectedly nil")
} else if resp.StatusCode != http.StatusInternalServerError {
t.Errorf("unexpected response StatusCode: %v", resp.StatusCode)
}
// we're done making requests, so close the return channel and drain it
close(serverReqErrCh)
for v := range serverReqErrCh {
if v != nil {
t.Errorf("unexpected error occurred when server processed request: %v", v)
}
}
}

func TestRetriesWithBodies_POST(t *testing.T) {
t.Parallel()

const testContent = "TestRetriesWithBodies_POST"
// using a channel to route these errors back into this goroutine
// it is important that this channel have enough capacity to hold all
// of the errors that will be generated by the test so that we do not
// deadlock. Therefore, MaxAttempts must be the same size as the channel capacity
// and each execution must only put at most one error on the channel.
serverReqErrCh := make(chan error, 4)
port, closeFn, err := middlewareServer(
contentVerificationMiddleware(serverReqErrCh, testContent),
always500RequestMiddleware(),
)
if err != nil {
t.Fatal("unable to start timeout server", err)
}
defer closeFn()

c := New()
c.MaxRetries = cap(serverReqErrCh)
c.KeepLog = true
c.Backoff = func(retry int) time.Duration {
// backoff isn't important for this test
return 0
}

iseUrl := fmt.Sprintf("http://localhost:%d", port)
resp, err := c.Post(iseUrl, "text/plain", strings.NewReader(testContent))
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if resp == nil {
t.Error("response was unexpectedly nil")
} else if resp.StatusCode != http.StatusInternalServerError {
t.Errorf("unexpected response StatusCode: %v", resp.StatusCode)
}
// we're done making requests, so close the return channel and drain it
close(serverReqErrCh)
for v := range serverReqErrCh {
if v != nil {
t.Errorf("unexpected error occurred when server processed request: %v", v)
}
}
}

func withinEpsilon(got, want int64, epslion float64) bool {
if want <= int64(epslion*float64(got)) || want >= int64(epslion*float64(got)) {
return false
Expand Down Expand Up @@ -880,3 +1046,61 @@ func serverWith400() (int, error) {

return port, nil
}

func contentVerificationMiddleware(errorCh chan<- error, expectedContent string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
content, err := ioutil.ReadAll(r.Body)
defer r.Body.Close()
if err != nil {
errorCh <- err
} else if string(content) != expectedContent {
errorCh <- fmt.Errorf(
"unexpected body content: expected \"%v\", got \"%v\"",
expectedContent,
string(content),
)
}
})
}

func always500RequestMiddleware() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("500 Internal Server Error"))
})
}

// middlewareServer stands up a server that accepts varags of middleware that conforms to the
// http.Handler interface
func middlewareServer(requestMiddleware ...http.Handler) (int, func(), error) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
for _, v := range requestMiddleware {
v.ServeHTTP(w, r)
}
})
l, err := net.Listen("tcp", ":0")
if err != nil {
return -1, nil, fmt.Errorf("unable to secure listener %v", err)
}
server := &http.Server{
Handler: mux,
}
go func() {
if err := server.Serve(l); err != nil && err != http.ErrServerClosed {
log.Fatalf("middleware-server error %v", err)
}
}()

var port int
_, sport, err := net.SplitHostPort(l.Addr().String())
if err == nil {
port, err = strconv.Atoi(sport)
}

if err != nil {
return -1, nil, fmt.Errorf("unable to determine port %v", err)
}

return port, func() { server.Close() }, nil
}