From ec25d3daa396b31ff93528a523627c1be146518b Mon Sep 17 00:00:00 2001 From: Karl Date: Tue, 17 Dec 2024 20:10:30 +0100 Subject: [PATCH] feat: handle empty get body (#3) --- transport.go | 32 +++++++++++++++++-- transport_test.go | 79 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 108 insertions(+), 3 deletions(-) diff --git a/transport.go b/transport.go index c2321e0..c71a1a7 100644 --- a/transport.go +++ b/transport.go @@ -1,6 +1,7 @@ package httpr import ( + "bytes" "io" "net/http" "time" @@ -51,9 +52,8 @@ func New(options ...Option) *Transport { // 408, 429, 500, 502, 503 and 504. var NewTransport = New -// RoundTrip satisfies the http.RoundTripper interface and performs an -// http request with the configured retry policy. -func (tr *Transport) RoundTrip(r *http.Request) (*http.Response, error) { +// setup sets the default values for the transport when they are not provided. +func (tr *Transport) setup() { if tr.tr == nil { tr.tr = http.DefaultTransport } @@ -70,6 +70,17 @@ func (tr *Transport) RoundTrip(r *http.Request) (*http.Response, error) { return 0 } } +} + +// RoundTrip satisfies the http.RoundTripper interface and performs an +// http request with the configured retry policy. +func (tr *Transport) RoundTrip(r *http.Request) (*http.Response, error) { + tr.setup() + if r.Body != nil && r.GetBody == nil { + if err := setGetBody(r); err != nil { + return nil, err + } + } backoff := tr.rp.Backoff retries := 0 @@ -127,3 +138,18 @@ func drainResponse(r *http.Response) error { } return nil } + +// setGetBody sets the GetBody method on the request +func setGetBody(r *http.Request) error { + body, err := io.ReadAll(r.Body) + if err != nil { + return err + } + r.Body = io.NopCloser(bytes.NewReader(body)) + r.ContentLength = int64(len(body)) + + r.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(body)), nil + } + return nil +} diff --git a/transport_test.go b/transport_test.go index 20f1a1f..0e13fe9 100644 --- a/transport_test.go +++ b/transport_test.go @@ -9,12 +9,64 @@ import ( "net" "net/http" "net/http/httptest" + "net/url" "testing" "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" ) +func TestNew(t *testing.T) { + var tests = []struct { + name string + input []Option + want *Transport + }{ + { + name: "default transport", + want: &Transport{ + tr: http.DefaultTransport, + rp: defaultRetryPolicy(), + }, + }, + { + name: "with retry policy", + input: []Option{ + WithRetryPolicy(RetryPolicy{ + ShouldRetry: StandardShouldRetry, + Backoff: ExponentialBackoff(), + MaxRetries: 5, + MinDelay: 1 * time.Second, + MaxDelay: 10 * time.Second, + Jitter: 0.1, + }), + }, + want: &Transport{ + tr: http.DefaultTransport, + rp: RetryPolicy{ + ShouldRetry: StandardShouldRetry, + Backoff: ExponentialBackoff(), + MaxRetries: 5, + MinDelay: 1 * time.Second, + MaxDelay: 10 * time.Second, + Jitter: 0.1, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got := New(test.input...) + + if diff := cmp.Diff(test.want, got, cmp.AllowUnexported(Transport{}, http.Transport{}), cmpopts.IgnoreUnexported(http.Transport{}), cmpopts.IgnoreFields(RetryPolicy{}, "ShouldRetry", "Backoff"), cmpopts.IgnoreFields(http.Transport{}, "Proxy", "DialContext")); diff != "" { + t.Errorf("New() = unexpected result (-want +got)\n%s\n", diff) + } + }) + } +} + func TestTransport_RoundTrip(t *testing.T) { type input struct { req func() *http.Request @@ -165,6 +217,33 @@ func TestTransport_RoundTrip(t *testing.T) { body: wantBodyPost, }, }, + { + name: "successful POST with retries - request literal", + input: input{ + req: func() *http.Request { + u, _ := url.Parse("http://example.com") + req := &http.Request{ + Method: http.MethodPost, + URL: u, + Body: io.NopCloser(bytes.NewReader(wantBodyPost)), + } + return req + }, + retryPolicy: RetryPolicy{ + ShouldRetry: StandardShouldRetry, + Backoff: ExponentialBackoff(), + MaxRetries: 3, + MinDelay: 1 * time.Millisecond, + MaxDelay: 5 * time.Millisecond, + }, + retries: 3, + err: errors.New("error"), + }, + want: want{ + statusCode: http.StatusOK, + body: wantBodyPost, + }, + }, } for _, test := range tests {