Skip to content

Commit

Permalink
[Fix] Avoid loading the response bodies twice in memory when parsing …
Browse files Browse the repository at this point in the history
…`bytes.Buffer` (#984)

## Changes

Prior to this PR, `WithResponseUnmarshal` was unmarshalling a response
body to a `bytes.Buffer` by:

1. Reading all the response body in a slice of bytes; and
2. Writing that slice in the receiving `bytes.Buffer`.

Step 1 unnecessarily increases the memory footprint as the response body
is essentially stored twice in memory during unmarshalling. This PR
fixes that problem by loading the body directly in the buffer without
passing by an intermediate container.

## Tests

Added a few unit test to cover all unmarshalling types.

Memory benchmark:

```go
var longString = strings.Repeat("a", 1<<30) 

func BenchmarkWithResponseUnmarshal_byteBuffer(b *testing.B) {
	b.ReportAllocs()

	client := NewApiClient(ClientConfig{
		Transport: hc(func(r *http.Request) (*http.Response, error) {
			return &http.Response{
				StatusCode: 200,
				Body:       io.NopCloser(strings.NewReader(longString)),
				Request:    r,
			}, nil
		}),
	})

	for i := 0; i < b.N; i++ {
		var got bytes.Buffer
		client.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got))
	}
}
```

Results:

```
Before:	6810268826 B/op	140 allocs/op
After:	4294972868 B/op	101 allocs/op
```

- [x] `make test` passing
- [x] `make fmt` applied
- [x] relevant integration tests applied
  • Loading branch information
renaudhartert-db committed Jul 17, 2024
1 parent b55a992 commit b71c73d
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 38 deletions.
51 changes: 31 additions & 20 deletions httpclient/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,17 @@ func WithResponseHeader(key string, value *string) DoOption {
}
}

// WithResponseUnmarshal unmarshals the response body into response. The
// supported response types are the following:
// - *bytes.Buffer,
// - *io.ReadCloser,
// - *[]byte,
// - a pointer to a struct with a Contents io.ReadCloser field,
// - a pointer to a struct representing a JSON object.
//
// If response is a pointer to a io.ReadCloser or a struct with a io.ReadCloser
// field name "Contents", then the response io.ReadCloser is set to the value of
// the body's reader without actually reading it.
func WithResponseUnmarshal(response any) DoOption {
return DoOption{
in: func(r *http.Request) error {
Expand All @@ -50,45 +61,45 @@ func WithResponseUnmarshal(response any) DoOption {
if err != nil {
return err
}
// If the field contains a "Content" field of type bytes.Buffer, write the body over there and return.
if field, ok := findContentsField(response, body); ok {
// If so, set the value

if field, ok := findContentsField(response); ok {
field.Set(reflect.ValueOf(body.ReadCloser))
return nil
}

// If the destination is bytes.Buffer, write the body over there
if raw, ok := response.(*io.ReadCloser); ok {
*raw = body.ReadCloser
if reader, ok := response.(*io.ReadCloser); ok {
*reader = body.ReadCloser
return nil
}
if buffer, ok := response.(*bytes.Buffer); ok {
defer body.ReadCloser.Close()
_, err := buffer.ReadFrom(body.ReadCloser)
return err
}

// At this point, fully read the content of the body and use it
// to populate the response object (whether it is a slice of bytes
// or a JSON object).
defer body.ReadCloser.Close()
bs, err := io.ReadAll(body.ReadCloser)
bodyBytes, err := io.ReadAll(body.ReadCloser)
if err != nil {
return fmt.Errorf("failed to read response body: %w", err)
}
if len(bs) == 0 {
if len(bodyBytes) == 0 {
return nil
}
// If the destination is a byte slice or buffer, pass the body verbatim.
if raw, ok := response.(*[]byte); ok {
*raw = bs
if bs, ok := response.(*[]byte); ok {
*bs = bodyBytes
return nil
}
if raw, ok := response.(*bytes.Buffer); ok {
_, err := raw.Write(bs)
return err
}
err = json.Unmarshal(bs, &response)
if err != nil {
return apierr.MakeUnexpectedError(body.Response, err, body.RequestBody.DebugBytes, bs)
if err = json.Unmarshal(bodyBytes, &response); err != nil {
return apierr.MakeUnexpectedError(body.Response, err, body.RequestBody.DebugBytes, bodyBytes)
}
return nil
},
}
}

func findContentsField(response any, body *common.ResponseWrapper) (*reflect.Value, bool) {
func findContentsField(response any) (*reflect.Value, bool) {
value := reflect.ValueOf(response)
value = reflect.Indirect(value)
if value.Kind() != reflect.Struct {
Expand Down
103 changes: 85 additions & 18 deletions httpclient/response_test.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,97 @@
package httpclient

import (
"bytes"
"context"
"io"
"net/http"
"strings"
"testing"

"github.com/stretchr/testify/require"
"github.com/stretchr/testify/assert"
)

func TestSimpleRequestRawResponse(t *testing.T) {
c := NewApiClient(ClientConfig{
func make200Response(body string) *http.Response {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader(body)),
}
}

func mockClient(resp *http.Response) *ApiClient {
return NewApiClient(ClientConfig{
Transport: hc(func(r *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: io.NopCloser(strings.NewReader("Hello, world!")),
Request: r,
}, nil
resp.Request = r
return resp, nil
}),
})
var raw []byte
err := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&raw))
require.NoError(t, err)
require.Equal(t, "Hello, world!", string(raw))
}

func TestWithResponseUnmarshal_structWithContent(t *testing.T) {
c := mockClient(make200Response("foo bar"))
type structWithContents = struct {
Contents io.ReadCloser
}
want := structWithContents{
Contents: io.NopCloser(strings.NewReader("foo bar")),
}

var got structWithContents
gotErr := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got))

assert.NoError(t, gotErr)
wantBytes, _ := io.ReadAll(want.Contents)
gotBytes, _ := io.ReadAll(got.Contents)
assert.Equal(t, wantBytes, gotBytes)
}

func TestWithResponseUnmarshal_readCloser(t *testing.T) {
c := mockClient(make200Response("foo bar"))
want := io.NopCloser(strings.NewReader("foo bar"))

var got io.ReadCloser
gotErr := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got))

assert.NoError(t, gotErr)
wantBytes, _ := io.ReadAll(want)
gotBytes, _ := io.ReadAll(got)
assert.Equal(t, wantBytes, gotBytes)
}

func TestWithResponseUnmarshal_byteBuffer(t *testing.T) {
c := mockClient(make200Response("foo bar"))
want := bytes.NewBuffer([]byte("foo bar"))

var got bytes.Buffer
gotErr := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got))

assert.NoError(t, gotErr)
assert.Equal(t, want.Bytes(), got.Bytes())
}

func TestWithResponseUnmarshal_bytes(t *testing.T) {
c := mockClient(make200Response("foo bar"))
want := []byte("foo bar")

var got []byte
gotErr := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got))

assert.NoError(t, gotErr)
assert.Equal(t, want, got)
}

func TestWithResponseUnmarshal_json(t *testing.T) {
c := mockClient(make200Response(`{"foo": "bar"}`))
type jsonStruct struct {
Foo string `json:"foo"`
}
want := jsonStruct{Foo: "bar"}

var got jsonStruct
gotErr := c.Do(context.Background(), "GET", "/a", WithResponseUnmarshal(&got))

assert.NoError(t, gotErr)
assert.Equal(t, want, got)
}

func TestWithResponseHeader(t *testing.T) {
Expand All @@ -40,10 +108,9 @@ func TestWithResponseHeader(t *testing.T) {
}),
})

var out string
ctx := context.Background()
err := client.Do(ctx, "GET", "abc",
WithResponseHeader("Foo", &out))
require.NoError(t, err)
require.Equal(t, "some", out)
var got string
gotErr := client.Do(context.Background(), "GET", "abc", WithResponseHeader("Foo", &got))

assert.NoError(t, gotErr)
assert.Equal(t, "some", got)
}

0 comments on commit b71c73d

Please sign in to comment.