Skip to content

Commit

Permalink
feat: add Responder.{SetContentLength,HeaderAdd,HeaderSet} methods
Browse files Browse the repository at this point in the history
Signed-off-by: Maxime Soulé <btik-git@scoubidou.com>
  • Loading branch information
maxatome committed Jan 17, 2023
1 parent db50b76 commit d4ab20c
Show file tree
Hide file tree
Showing 2 changed files with 329 additions and 1 deletion.
164 changes: 163 additions & 1 deletion response.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,153 @@ func (r Responder) Then(next Responder) (x Responder) {
return
}

// SetContentLength returns a new [Responder] based on r that ensures
// the returned [*http.Response] ContentLength field and
// Content-Length header are set to the right value.
//
// If r returns an [*http.Response] with a nil Body or equal to
// [http.NoBody], the length is always set to 0.
//
// If r returned response.Body implements:
//
// Len() int
//
// then the length is set to the Body.Len() returned value. All
// httpmock generated bodies implement this method. Beware that
// [strings.Builder], [strings.Reader], [bytes.Buffer] and
// [bytes.Reader] types used with [io.NopCloser] do not implement
// Len() anymore.
//
// Otherwise, r returned response.Body is entirely copied into an
// internal buffer to get its length, then it is closed. The Body of
// the [*http.Response] returned by the [Responder] returned by
// SetContentLength can then be read again to return its content as
// usual. But keep in mind that each time this [Responder] is called,
// r is called first. So this one has to carefully handle its body: it
// is highly recommended to use [NewRespBodyFromString] or
// [NewRespBodyFromBytes] to set the body once (as
// [NewStringResponder] and [NewBytesResponder] do behind the scene),
// or to build the body each time r is called.
//
// The following calls are all correct:
//
// responder = httpmock.NewStringResponder(200, "BODY").SetContentLength()
// responder = httpmock.NewBytesResponder(200, []byte("BODY")).SetContentLength()
// responder = ResponderFromResponse(&http.Response{
// // build a body once, but httpmock knows how to "rearm" it once read
// Body: NewRespBodyFromString("BODY"),
// StatusCode: 200,
// }).SetContentLength()
// responder = httpmock.Responder(func(req *http.Request) (*http.Response, error) {
// // build a new body for each call
// return &http.Response{
// StatusCode: 200,
// Body: io.NopCloser(strings.NewReader("BODY")),
// }, nil
// }).SetContentLength()
//
// But the following is not correct:
//
// responder = httpmock.ResponderFromResponse(&http.Response{
// StatusCode: 200,
// Body: io.NopCloser(strings.NewReader("BODY")),
// }).SetContentLength()
//
// it will only succeed for the first responder call. The following
// calls will deliver responses with an empty body, as it will already
// been read by the first call.
func (r Responder) SetContentLength() Responder {
return func(req *http.Request) (*http.Response, error) {
resp, err := r(req)
if err != nil {
return nil, err
}
nr := *resp
switch nr.Body {
case nil:
nr.Body = http.NoBody
fallthrough
case http.NoBody:
nr.ContentLength = 0
default:
bl, ok := nr.Body.(interface{ Len() int })
if !ok {
copyBody := &dummyReadCloser{orig: nr.Body}
bl, nr.Body = copyBody, copyBody
}
nr.ContentLength = int64(bl.Len())
}
if nr.Header == nil {
nr.Header = http.Header{}
}
nr.Header = nr.Header.Clone()
nr.Header.Set("Content-Length", strconv.FormatInt(nr.ContentLength, 10))
return &nr, nil
}
}

// HeaderAdd returns a new [Responder] based on r that ensures the
// returned [*http.Response] includes h header. It adds each h entry
// to the header. It appends to any existing values associated with
// each h key. Each key is case insensitive; it is canonicalized by
// [http.CanonicalHeaderKey].
//
// See also [Responder.HeaderSet] and [Responder.SetContentLength].
func (r Responder) HeaderAdd(h http.Header) Responder {
return func(req *http.Request) (*http.Response, error) {
resp, err := r(req)
if err != nil {
return nil, err
}
nr := *resp
if nr.Header == nil {
nr.Header = make(http.Header, len(h))
}
nr.Header = nr.Header.Clone()
for k, v := range h {
k = http.CanonicalHeaderKey(k)
if v == nil {
if _, ok := nr.Header[k]; !ok {
nr.Header[k] = nil
}
continue
}
nr.Header[k] = append(nr.Header[k], v...)
}
return &nr, nil
}
}

// HeaderSet returns a new [Responder] based on r that ensures the
// returned [*http.Response] includes h header. It sets the header
// entries associated with each h key. It replaces any existing values
// associated each h key. Each key is case insensitive; it is
// canonicalized by [http.CanonicalHeaderKey].
//
// See also [Responder.HeaderAdd] and [Responder.SetContentLength].
func (r Responder) HeaderSet(h http.Header) Responder {
return func(req *http.Request) (*http.Response, error) {
resp, err := r(req)
if err != nil {
return nil, err
}
nr := *resp
if nr.Header == nil {
nr.Header = make(http.Header, len(h))
}
nr.Header = nr.Header.Clone()
for k, v := range h {
k = http.CanonicalHeaderKey(k)
if v == nil {
nr.Header[k] = nil
continue
}
nr.Header[k] = append([]string(nil), v...)
}
return &nr, nil
}
}

// ResponderFromResponse wraps an [*http.Response] in a [Responder].
//
// Be careful, except for responses generated by httpmock
Expand Down Expand Up @@ -560,9 +707,14 @@ func NewRespBodyFromBytes(body []byte) io.ReadCloser {
return &dummyReadCloser{orig: body}
}

type lenReadSeeker interface {
io.ReadSeeker
Len() int
}

type dummyReadCloser struct {
orig any // string or []byte
body io.ReadSeeker // instanciated on demand from orig
body lenReadSeeker // instanciated on demand from orig
}

// copy returns a new instance resetting d.body to nil.
Expand All @@ -578,6 +730,11 @@ func (d *dummyReadCloser) setup() {
d.body = strings.NewReader(body)
case []byte:
d.body = bytes.NewReader(body)
case io.ReadCloser:
var buf bytes.Buffer
io.Copy(&buf, body) //nolint: errcheck
body.Close()
d.body = bytes.NewReader(buf.Bytes())
}
}
}
Expand All @@ -592,3 +749,8 @@ func (d *dummyReadCloser) Close() error {
d.body.Seek(0, io.SeekEnd) // nolint: errcheck
return nil
}

func (d *dummyReadCloser) Len() int {
d.setup()
return d.body.Len()
}
166 changes: 166 additions & 0 deletions response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io/ioutil" //nolint: staticcheck
"net/http"
"path/filepath"
"strconv"
"strings"
"sync"
"testing"
Expand Down Expand Up @@ -586,6 +587,171 @@ func TestResponder_Then(t *testing.T) {
})
}

func TestResponder_SetContentLength(t *testing.T) {
assert, require := td.AssertRequire(t)

req, err := http.NewRequest(http.MethodGet, "http://foo.bar", nil)
require.CmpNoError(err)

testCases := []struct {
name string
r Responder
expLen int
}{
{
name: "nil body",
r: ResponderFromResponse(&http.Response{
StatusCode: 200,
ContentLength: -1,
}),
expLen: 0,
},
{
name: "http.NoBody",
r: ResponderFromResponse(&http.Response{
Body: http.NoBody,
StatusCode: 200,
ContentLength: -1,
}),
expLen: 0,
},
{
name: "string",
r: NewStringResponder(200, "BODY"),
expLen: 4,
},
{
name: "bytes",
r: NewBytesResponder(200, []byte("BODY")),
expLen: 4,
},
{
name: "from response OK",
r: ResponderFromResponse(&http.Response{
Body: NewRespBodyFromString("BODY"),
StatusCode: 200,
ContentLength: -1,
}),
expLen: 4,
},
{
name: "custom without Len",
r: func(req *http.Request) (*http.Response, error) {
return &http.Response{
Body: ioutil.NopCloser(strings.NewReader("BODY")),
StatusCode: 200,
ContentLength: -1,
}, nil
},
expLen: 4,
},
}
for _, tc := range testCases {
assert.Run(tc.name, func(assert *td.T) {
sclr := tc.r.SetContentLength()

for i := 1; i <= 3; i++ {
assert.RunAssertRequire(fmt.Sprintf("#%d", i), func(assert, require *td.T) {
resp, err := sclr(req)
require.CmpNoError(err)
assert.CmpLax(resp.ContentLength, tc.expLen)
assert.Cmp(resp.Header.Get("Content-Length"), strconv.Itoa(tc.expLen))
})
}
})
}

assert.Run("error", func(assert *td.T) {
resp, err := NewErrorResponder(errors.New("an error occurred")).
SetContentLength()(req)
assert.Nil(resp)
assert.String(err, "an error occurred")
})
}

func TestResponder_HeaderAddSet(t *testing.T) {
assert, require := td.AssertRequire(t)

req, err := http.NewRequest(http.MethodGet, "http://foo.bar", nil)
require.CmpNoError(err)

orig := NewStringResponder(200, "body")
origNilHeader := ResponderFromResponse(&http.Response{
Status: "200",
StatusCode: 200,
Body: NewRespBodyFromString("body"),
ContentLength: -1,
})

// until go1.17, http.Header cannot contain nil values after a Header.Clone()
clonedNil := http.Header{"Nil": nil}.Clone()["Nil"]

testCases := []struct {
name string
orig Responder
}{
{name: "orig", orig: orig},
{name: "nil header", orig: origNilHeader},
}
assert.RunAssertRequire("HeaderAdd", func(assert, require *td.T) {
for _, tc := range testCases {
assert.RunAssertRequire(tc.name, func(assert, require *td.T) {
r := tc.orig.HeaderAdd(http.Header{"foo": {"bar"}, "nil": nil})
resp, err := r(req)
require.CmpNoError(err)
assert.Cmp(resp.Header, http.Header{"Foo": {"bar"}, "Nil": nil})

r = r.HeaderAdd(http.Header{"foo": {"zip"}, "test": {"pipo"}})
resp, err = r(req)
require.CmpNoError(err)
assert.Cmp(resp.Header, http.Header{"Foo": {"bar", "zip"}, "Test": {"pipo"}, "Nil": clonedNil})
})
}

resp, err := orig(req)
require.CmpNoError(err)
assert.Empty(resp.Header)
})

assert.RunAssertRequire("HeaderSet", func(assert, require *td.T) {
for _, tc := range testCases {
assert.RunAssertRequire(tc.name, func(assert, require *td.T) {
r := tc.orig.HeaderSet(http.Header{"foo": {"bar"}, "nil": nil})
resp, err := r(req)
require.CmpNoError(err)
assert.Cmp(resp.Header, http.Header{"Foo": {"bar"}, "Nil": nil})

r = r.HeaderSet(http.Header{"foo": {"zip"}, "test": {"pipo"}})
resp, err = r(req)
require.CmpNoError(err)
assert.Cmp(resp.Header, http.Header{"Foo": {"zip"}, "Test": {"pipo"}, "Nil": clonedNil})
})
}

resp, err := orig(req)
require.CmpNoError(err)
assert.Empty(resp.Header)
})

assert.Run("error", func(assert *td.T) {
origErr := NewErrorResponder(errors.New("an error occurred"))

assert.Run("HeaderAdd", func(assert *td.T) {
r := origErr.HeaderAdd(http.Header{"foo": {"bar"}})
resp, err := r(req)
assert.Nil(resp)
assert.String(err, "an error occurred")
})

assert.Run("HeaderSet", func(assert *td.T) {
r := origErr.HeaderSet(http.Header{"foo": {"bar"}})
resp, err := r(req)
assert.Nil(resp)
assert.String(err, "an error occurred")
})
})
}

func TestParallelResponder(t *testing.T) {
req, err := http.NewRequest(http.MethodGet, "http://foo.bar", nil)
td.Require(t).CmpNoError(err)
Expand Down

0 comments on commit d4ab20c

Please sign in to comment.