Skip to content

Commit

Permalink
Add tests to verify retry behaviour with POST requests
Browse files Browse the repository at this point in the history
* Add test for POST-requests provided by Do() with a body
* Add test for POST-requests provided by Post() with a body
* Add middlewareServer to support POST tests
  • Loading branch information
lvanoort committed Feb 8, 2022
1 parent 26b8a6d commit c1c03d2
Showing 1 changed file with 158 additions and 0 deletions.
158 changes: 158 additions & 0 deletions pester_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
Expand Down Expand Up @@ -806,7 +807,106 @@ func TestRetriesContextCancelledDuringWait(t *testing.T) {
// give it a second, then treat this as failing to return
t.Fatal("failed to receive error return")
}
}

func TestRetriesWithBodies_Do(t *testing.T) {
t.Parallel()

const testContent = "TestRetriesWithBodies_Do"
// using a channel to route these errors back into this goroutine
// it is important that this channel have enough capacity to hold all
// of the errors that will be generated by the test so that we do not
// deadlock. Therefore, MaxAttempts must be the same size as the channel capacity
// and each execution must only put at most one error on the channel.
serverReqErrCh := make(chan error, 4)
port, closeFn, err := middlewareServer(
contentVerificationMiddleware(serverReqErrCh, testContent),
always500RequestMiddleware(),
)
if err != nil {
t.Fatal("unable to start timeout server", err)
}
defer closeFn()

<-time.After(2 * time.Second)

iseUrl := fmt.Sprintf("http://localhost:%d", port)

req, err := http.NewRequest("POST", iseUrl, strings.NewReader(testContent))
if err != nil {
t.Fatalf("unable to create request %v", err)
}

c := New()
c.MaxRetries = cap(serverReqErrCh)
c.KeepLog = true
c.Backoff = func(retry int) time.Duration {
// backoff isn't important for this test
return 0
}

resp, err := c.Do(req)
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if resp == nil {
t.Error("response was unexpectedly nil")
} else if resp.StatusCode != http.StatusInternalServerError {
t.Errorf("unexpected response StatusCode: %v", resp.StatusCode)
}
// we're done making requests, so close the return channel and drain it
close(serverReqErrCh)
for v := range serverReqErrCh {
if v != nil {
t.Errorf("unexpected error occurred when server processed request: %v", v)
}
}
}

func TestRetriesWithBodies_POST(t *testing.T) {
t.Parallel()

const testContent = "TestRetriesWithBodies_POST"
// using a channel to route these errors back into this goroutine
// it is important that this channel have enough capacity to hold all
// of the errors that will be generated by the test so that we do not
// deadlock. Therefore, MaxAttempts must be the same size as the channel capacity
// and each execution must only put at most one error on the channel.
serverReqErrCh := make(chan error, 4)
port, closeFn, err := middlewareServer(
contentVerificationMiddleware(serverReqErrCh, testContent),
always500RequestMiddleware(),
)
if err != nil {
t.Fatal("unable to start timeout server", err)
}
defer closeFn()

c := New()
c.MaxRetries = cap(serverReqErrCh)
c.KeepLog = true
c.Backoff = func(retry int) time.Duration {
// backoff isn't important for this test
return 0
}

iseUrl := fmt.Sprintf("http://localhost:%d", port)
resp, err := c.Post(iseUrl, "text/plain", strings.NewReader(testContent))
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if resp == nil {
t.Error("response was unexpectedly nil")
} else if resp.StatusCode != http.StatusInternalServerError {
t.Errorf("unexpected response StatusCode: %v", resp.StatusCode)
}
// we're done making requests, so close the return channel and drain it
close(serverReqErrCh)
for v := range serverReqErrCh {
if v != nil {
t.Errorf("unexpected error occurred when server processed request: %v", v)
}
}
}

func withinEpsilon(got, want int64, epslion float64) bool {
Expand Down Expand Up @@ -946,3 +1046,61 @@ func serverWith400() (int, error) {

return port, nil
}

func contentVerificationMiddleware(errorCh chan<- error, expectedContent string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
content, err := ioutil.ReadAll(r.Body)
defer r.Body.Close()
if err != nil {
errorCh <- err
} else if string(content) != expectedContent {
errorCh <- fmt.Errorf(
"unexpected body content: expected \"%v\", got \"%v\"",
expectedContent,
string(content),
)
}
})
}

func always500RequestMiddleware() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte("500 Internal Server Error"))
})
}

// middlewareServer stands up a server that accepts varags of middleware that conforms to the
// http.Handler interface
func middlewareServer(requestMiddleware ...http.Handler) (int, func(), error) {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
for _, v := range requestMiddleware {
v.ServeHTTP(w, r)
}
})
l, err := net.Listen("tcp", ":0")
if err != nil {
return -1, nil, fmt.Errorf("unable to secure listener %v", err)
}
server := &http.Server{
Handler: mux,
}
go func() {
if err := server.Serve(l); err != nil && err != http.ErrServerClosed {
log.Fatalf("middleware-server error %v", err)
}
}()

var port int
_, sport, err := net.SplitHostPort(l.Addr().String())
if err == nil {
port, err = strconv.Atoi(sport)
}

if err != nil {
return -1, nil, fmt.Errorf("unable to determine port %v", err)
}

return port, func() { server.Close() }, nil
}

0 comments on commit c1c03d2

Please sign in to comment.