From db50b7669db74ba622c08cb8be04231521d194a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maxime=20Soul=C3=A9?= Date: Mon, 16 Jan 2023 22:17:13 +0100 Subject: [PATCH 1/2] perf(matchers): simplify http.NoBody case MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Maxime Soulé --- match.go | 12 ++++++-- match_test.go | 85 ++++++++++++++++++++++++++++++++++++++------------- 2 files changed, 73 insertions(+), 24 deletions(-) diff --git a/match.go b/match.go index c91d337..baf46ee 100644 --- a/match.go +++ b/match.go @@ -466,8 +466,8 @@ func (m matchRouteKey) String() string { } // bodyCopyOnRead copies body content to buf on first Read(), except -// if body is nil. In this case, EOF is returned for each Read() and -// buf stays to nil. +// if body is nil or http.NoBody. In this case, EOF is returned for +// each Read() and buf stays to nil. type bodyCopyOnRead struct { body io.ReadCloser buf []byte @@ -480,7 +480,7 @@ func (b *bodyCopyOnRead) rearm() { } func (b *bodyCopyOnRead) copy() { - if b.buf == nil && b.body != nil { + if b.buf == nil && b.body != nil && b.body != http.NoBody { var body bytes.Buffer io.Copy(&body, b.body) //nolint: errcheck b.body.Close() @@ -500,3 +500,9 @@ func (b *bodyCopyOnRead) Read(p []byte) (n int, err error) { func (b *bodyCopyOnRead) Close() error { return nil } + +// Len returns the buffer total length, whatever the Read position in body is. +func (b *bodyCopyOnRead) Len() int { + b.copy() + return len(b.buf) +} diff --git a/match_test.go b/match_test.go index 57cabb7..2984d79 100644 --- a/match_test.go +++ b/match_test.go @@ -418,28 +418,71 @@ func TestBodyCopyOnRead(t *testing.T) { td.CmpNoError(t, bc.Close()) }) - t.Run("nil body", func(t *testing.T) { - bc := httpmock.NewBodyCopyOnRead(nil) - - bc.Rearm() - td.CmpNil(t, bc.Buf()) - - var buf [4]byte - n, err := bc.Read(buf[:]) - td.Cmp(t, err, io.EOF) - td.Cmp(t, n, 0) - td.CmpNil(t, bc.Buf()) - td.Cmp(t, bc.Body(), nil) - - bc.Rearm() - - n, err = bc.Read(buf[:]) - td.Cmp(t, err, io.EOF) - td.Cmp(t, n, 0) - td.CmpNil(t, bc.Buf()) - td.Cmp(t, bc.Body(), nil) + testCases := []struct { + name string + body io.ReadCloser + }{ + { + name: "nil body", + }, + { + name: "no body", + body: http.NoBody, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + bc := httpmock.NewBodyCopyOnRead(tc.body) + + bc.Rearm() + td.CmpNil(t, bc.Buf()) + + var buf [4]byte + n, err := bc.Read(buf[:]) + td.Cmp(t, err, io.EOF) + td.Cmp(t, n, 0) + td.CmpNil(t, bc.Buf()) + td.Cmp(t, bc.Body(), tc.body) + + bc.Rearm() + + n, err = bc.Read(buf[:]) + td.Cmp(t, err, io.EOF) + td.Cmp(t, n, 0) + td.CmpNil(t, bc.Buf()) + td.Cmp(t, bc.Body(), tc.body) + + td.CmpNoError(t, bc.Close()) + }) + } - td.CmpNoError(t, bc.Close()) + t.Run("len", func(t *testing.T) { + testCases := []struct { + name string + bc interface{ Len() int } + expected int + }{ + { + name: "nil", + bc: httpmock.NewBodyCopyOnRead(nil), + expected: 0, + }, + { + name: "no body", + bc: httpmock.NewBodyCopyOnRead(http.NoBody), + expected: 0, + }, + { + name: "filled", + bc: httpmock.NewBodyCopyOnRead(ioutil.NopCloser(bytes.NewReader([]byte(`BODY`)))), + expected: 4, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + td.Cmp(t, tc.bc.Len(), tc.expected) + }) + } }) } From d4ab20c68d8b0c8ac08908ef5dc7bbc80d9b462f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Maxime=20Soul=C3=A9?= Date: Mon, 16 Jan 2023 22:22:47 +0100 Subject: [PATCH 2/2] feat: add Responder.{SetContentLength,HeaderAdd,HeaderSet} methods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Maxime Soulé --- response.go | 164 +++++++++++++++++++++++++++++++++++++++++++++- response_test.go | 166 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 329 insertions(+), 1 deletion(-) diff --git a/response.go b/response.go index b6274e7..fd9ecd4 100644 --- a/response.go +++ b/response.go @@ -217,6 +217,153 @@ func (r Responder) Then(next Responder) (x Responder) { return } +// SetContentLength returns a new [Responder] based on r that ensures +// the returned [*http.Response] ContentLength field and +// Content-Length header are set to the right value. +// +// If r returns an [*http.Response] with a nil Body or equal to +// [http.NoBody], the length is always set to 0. +// +// If r returned response.Body implements: +// +// Len() int +// +// then the length is set to the Body.Len() returned value. All +// httpmock generated bodies implement this method. Beware that +// [strings.Builder], [strings.Reader], [bytes.Buffer] and +// [bytes.Reader] types used with [io.NopCloser] do not implement +// Len() anymore. +// +// Otherwise, r returned response.Body is entirely copied into an +// internal buffer to get its length, then it is closed. The Body of +// the [*http.Response] returned by the [Responder] returned by +// SetContentLength can then be read again to return its content as +// usual. But keep in mind that each time this [Responder] is called, +// r is called first. So this one has to carefully handle its body: it +// is highly recommended to use [NewRespBodyFromString] or +// [NewRespBodyFromBytes] to set the body once (as +// [NewStringResponder] and [NewBytesResponder] do behind the scene), +// or to build the body each time r is called. +// +// The following calls are all correct: +// +// responder = httpmock.NewStringResponder(200, "BODY").SetContentLength() +// responder = httpmock.NewBytesResponder(200, []byte("BODY")).SetContentLength() +// responder = ResponderFromResponse(&http.Response{ +// // build a body once, but httpmock knows how to "rearm" it once read +// Body: NewRespBodyFromString("BODY"), +// StatusCode: 200, +// }).SetContentLength() +// responder = httpmock.Responder(func(req *http.Request) (*http.Response, error) { +// // build a new body for each call +// return &http.Response{ +// StatusCode: 200, +// Body: io.NopCloser(strings.NewReader("BODY")), +// }, nil +// }).SetContentLength() +// +// But the following is not correct: +// +// responder = httpmock.ResponderFromResponse(&http.Response{ +// StatusCode: 200, +// Body: io.NopCloser(strings.NewReader("BODY")), +// }).SetContentLength() +// +// it will only succeed for the first responder call. The following +// calls will deliver responses with an empty body, as it will already +// been read by the first call. +func (r Responder) SetContentLength() Responder { + return func(req *http.Request) (*http.Response, error) { + resp, err := r(req) + if err != nil { + return nil, err + } + nr := *resp + switch nr.Body { + case nil: + nr.Body = http.NoBody + fallthrough + case http.NoBody: + nr.ContentLength = 0 + default: + bl, ok := nr.Body.(interface{ Len() int }) + if !ok { + copyBody := &dummyReadCloser{orig: nr.Body} + bl, nr.Body = copyBody, copyBody + } + nr.ContentLength = int64(bl.Len()) + } + if nr.Header == nil { + nr.Header = http.Header{} + } + nr.Header = nr.Header.Clone() + nr.Header.Set("Content-Length", strconv.FormatInt(nr.ContentLength, 10)) + return &nr, nil + } +} + +// HeaderAdd returns a new [Responder] based on r that ensures the +// returned [*http.Response] includes h header. It adds each h entry +// to the header. It appends to any existing values associated with +// each h key. Each key is case insensitive; it is canonicalized by +// [http.CanonicalHeaderKey]. +// +// See also [Responder.HeaderSet] and [Responder.SetContentLength]. +func (r Responder) HeaderAdd(h http.Header) Responder { + return func(req *http.Request) (*http.Response, error) { + resp, err := r(req) + if err != nil { + return nil, err + } + nr := *resp + if nr.Header == nil { + nr.Header = make(http.Header, len(h)) + } + nr.Header = nr.Header.Clone() + for k, v := range h { + k = http.CanonicalHeaderKey(k) + if v == nil { + if _, ok := nr.Header[k]; !ok { + nr.Header[k] = nil + } + continue + } + nr.Header[k] = append(nr.Header[k], v...) + } + return &nr, nil + } +} + +// HeaderSet returns a new [Responder] based on r that ensures the +// returned [*http.Response] includes h header. It sets the header +// entries associated with each h key. It replaces any existing values +// associated each h key. Each key is case insensitive; it is +// canonicalized by [http.CanonicalHeaderKey]. +// +// See also [Responder.HeaderAdd] and [Responder.SetContentLength]. +func (r Responder) HeaderSet(h http.Header) Responder { + return func(req *http.Request) (*http.Response, error) { + resp, err := r(req) + if err != nil { + return nil, err + } + nr := *resp + if nr.Header == nil { + nr.Header = make(http.Header, len(h)) + } + nr.Header = nr.Header.Clone() + for k, v := range h { + k = http.CanonicalHeaderKey(k) + if v == nil { + nr.Header[k] = nil + continue + } + nr.Header[k] = append([]string(nil), v...) + } + return &nr, nil + } +} + // ResponderFromResponse wraps an [*http.Response] in a [Responder]. // // Be careful, except for responses generated by httpmock @@ -560,9 +707,14 @@ func NewRespBodyFromBytes(body []byte) io.ReadCloser { return &dummyReadCloser{orig: body} } +type lenReadSeeker interface { + io.ReadSeeker + Len() int +} + type dummyReadCloser struct { orig any // string or []byte - body io.ReadSeeker // instanciated on demand from orig + body lenReadSeeker // instanciated on demand from orig } // copy returns a new instance resetting d.body to nil. @@ -578,6 +730,11 @@ func (d *dummyReadCloser) setup() { d.body = strings.NewReader(body) case []byte: d.body = bytes.NewReader(body) + case io.ReadCloser: + var buf bytes.Buffer + io.Copy(&buf, body) //nolint: errcheck + body.Close() + d.body = bytes.NewReader(buf.Bytes()) } } } @@ -592,3 +749,8 @@ func (d *dummyReadCloser) Close() error { d.body.Seek(0, io.SeekEnd) // nolint: errcheck return nil } + +func (d *dummyReadCloser) Len() int { + d.setup() + return d.body.Len() +} diff --git a/response_test.go b/response_test.go index 496deff..d491685 100644 --- a/response_test.go +++ b/response_test.go @@ -8,6 +8,7 @@ import ( "io/ioutil" //nolint: staticcheck "net/http" "path/filepath" + "strconv" "strings" "sync" "testing" @@ -586,6 +587,171 @@ func TestResponder_Then(t *testing.T) { }) } +func TestResponder_SetContentLength(t *testing.T) { + assert, require := td.AssertRequire(t) + + req, err := http.NewRequest(http.MethodGet, "http://foo.bar", nil) + require.CmpNoError(err) + + testCases := []struct { + name string + r Responder + expLen int + }{ + { + name: "nil body", + r: ResponderFromResponse(&http.Response{ + StatusCode: 200, + ContentLength: -1, + }), + expLen: 0, + }, + { + name: "http.NoBody", + r: ResponderFromResponse(&http.Response{ + Body: http.NoBody, + StatusCode: 200, + ContentLength: -1, + }), + expLen: 0, + }, + { + name: "string", + r: NewStringResponder(200, "BODY"), + expLen: 4, + }, + { + name: "bytes", + r: NewBytesResponder(200, []byte("BODY")), + expLen: 4, + }, + { + name: "from response OK", + r: ResponderFromResponse(&http.Response{ + Body: NewRespBodyFromString("BODY"), + StatusCode: 200, + ContentLength: -1, + }), + expLen: 4, + }, + { + name: "custom without Len", + r: func(req *http.Request) (*http.Response, error) { + return &http.Response{ + Body: ioutil.NopCloser(strings.NewReader("BODY")), + StatusCode: 200, + ContentLength: -1, + }, nil + }, + expLen: 4, + }, + } + for _, tc := range testCases { + assert.Run(tc.name, func(assert *td.T) { + sclr := tc.r.SetContentLength() + + for i := 1; i <= 3; i++ { + assert.RunAssertRequire(fmt.Sprintf("#%d", i), func(assert, require *td.T) { + resp, err := sclr(req) + require.CmpNoError(err) + assert.CmpLax(resp.ContentLength, tc.expLen) + assert.Cmp(resp.Header.Get("Content-Length"), strconv.Itoa(tc.expLen)) + }) + } + }) + } + + assert.Run("error", func(assert *td.T) { + resp, err := NewErrorResponder(errors.New("an error occurred")). + SetContentLength()(req) + assert.Nil(resp) + assert.String(err, "an error occurred") + }) +} + +func TestResponder_HeaderAddSet(t *testing.T) { + assert, require := td.AssertRequire(t) + + req, err := http.NewRequest(http.MethodGet, "http://foo.bar", nil) + require.CmpNoError(err) + + orig := NewStringResponder(200, "body") + origNilHeader := ResponderFromResponse(&http.Response{ + Status: "200", + StatusCode: 200, + Body: NewRespBodyFromString("body"), + ContentLength: -1, + }) + + // until go1.17, http.Header cannot contain nil values after a Header.Clone() + clonedNil := http.Header{"Nil": nil}.Clone()["Nil"] + + testCases := []struct { + name string + orig Responder + }{ + {name: "orig", orig: orig}, + {name: "nil header", orig: origNilHeader}, + } + assert.RunAssertRequire("HeaderAdd", func(assert, require *td.T) { + for _, tc := range testCases { + assert.RunAssertRequire(tc.name, func(assert, require *td.T) { + r := tc.orig.HeaderAdd(http.Header{"foo": {"bar"}, "nil": nil}) + resp, err := r(req) + require.CmpNoError(err) + assert.Cmp(resp.Header, http.Header{"Foo": {"bar"}, "Nil": nil}) + + r = r.HeaderAdd(http.Header{"foo": {"zip"}, "test": {"pipo"}}) + resp, err = r(req) + require.CmpNoError(err) + assert.Cmp(resp.Header, http.Header{"Foo": {"bar", "zip"}, "Test": {"pipo"}, "Nil": clonedNil}) + }) + } + + resp, err := orig(req) + require.CmpNoError(err) + assert.Empty(resp.Header) + }) + + assert.RunAssertRequire("HeaderSet", func(assert, require *td.T) { + for _, tc := range testCases { + assert.RunAssertRequire(tc.name, func(assert, require *td.T) { + r := tc.orig.HeaderSet(http.Header{"foo": {"bar"}, "nil": nil}) + resp, err := r(req) + require.CmpNoError(err) + assert.Cmp(resp.Header, http.Header{"Foo": {"bar"}, "Nil": nil}) + + r = r.HeaderSet(http.Header{"foo": {"zip"}, "test": {"pipo"}}) + resp, err = r(req) + require.CmpNoError(err) + assert.Cmp(resp.Header, http.Header{"Foo": {"zip"}, "Test": {"pipo"}, "Nil": clonedNil}) + }) + } + + resp, err := orig(req) + require.CmpNoError(err) + assert.Empty(resp.Header) + }) + + assert.Run("error", func(assert *td.T) { + origErr := NewErrorResponder(errors.New("an error occurred")) + + assert.Run("HeaderAdd", func(assert *td.T) { + r := origErr.HeaderAdd(http.Header{"foo": {"bar"}}) + resp, err := r(req) + assert.Nil(resp) + assert.String(err, "an error occurred") + }) + + assert.Run("HeaderSet", func(assert *td.T) { + r := origErr.HeaderSet(http.Header{"foo": {"bar"}}) + resp, err := r(req) + assert.Nil(resp) + assert.String(err, "an error occurred") + }) + }) +} + func TestParallelResponder(t *testing.T) { req, err := http.NewRequest(http.MethodGet, "http://foo.bar", nil) td.Require(t).CmpNoError(err)