From f37dd841f32683a4397ec1c7dd17504517a48c2e Mon Sep 17 00:00:00 2001 From: Calvin Lobo Date: Thu, 14 Sep 2023 13:27:55 -0400 Subject: [PATCH] Added a RoundTripper to simplehttp that wraps around another RoundTripper and converts 4XX and 5XX series errors to SimpleError with the corresponding code set, as defined in the inverse mapping --- ecosystem/http/roundtripper.go | 55 ++++++++++++++++++++ ecosystem/http/roundtripper_test.go | 79 +++++++++++++++++++++++++++++ 2 files changed, 134 insertions(+) create mode 100644 ecosystem/http/roundtripper.go create mode 100644 ecosystem/http/roundtripper_test.go diff --git a/ecosystem/http/roundtripper.go b/ecosystem/http/roundtripper.go new file mode 100644 index 0000000..7eb4dce --- /dev/null +++ b/ecosystem/http/roundtripper.go @@ -0,0 +1,55 @@ +package simplehttp + +import ( + "github.com/lobocv/simplerr" + "net/http" +) + +type attr int + +const ( + attrHTTPResponse = attr(1) +) + +// roundTripper is a wrapper around the given http.RoundTripper that converts 4XX and 5XX series errors to SimpleErrors +type roundTripper struct { + rt http.RoundTripper +} + +func (s roundTripper) RoundTrip(request *http.Request) (*http.Response, error) { + resp, err := s.rt.RoundTrip(request) + if err != nil { + return resp, err + } + + if resp.StatusCode >= 400 && resp.StatusCode < 600 { + code, _ := GetCode(resp.StatusCode) + serr := simplerr.New("%s", resp.Status). + Code(code). + Attr(attrHTTPResponse, resp) + return nil, serr + } + + return resp, nil +} + +// EnableHTTPStatusErrors wraps the http.RoundTripper in middleware that converts 4XX and 5XX series errors to SimpleErrors +// with the code defined in the inverse mapping. +func EnableHTTPStatusErrors(rt http.RoundTripper) http.RoundTripper { + return roundTripper{rt: rt} +} + +// GetHTTPResponseAttr gets the *http.Response attached to the error, if it exists. +func GetHTTPResponseAttr(err error) *http.Response { + v, ok := simplerr.GetAttribute(err, attrHTTPResponse) + if !ok { + return nil + } + + resp, ok := v.(*http.Response) + if !ok { + return nil + } + + return resp +} diff --git a/ecosystem/http/roundtripper_test.go b/ecosystem/http/roundtripper_test.go new file mode 100644 index 0000000..257c74d --- /dev/null +++ b/ecosystem/http/roundtripper_test.go @@ -0,0 +1,79 @@ +package simplehttp + +import ( + "fmt" + "github.com/lobocv/simplerr" + "github.com/stretchr/testify/require" + "net/http" + "testing" +) + +type dummyTransport struct { + response *http.Response + err error +} + +func (d dummyTransport) RoundTrip(_ *http.Request) (*http.Response, error) { + return d.response, d.err +} + +func TestRoundTripperConvertCode(t *testing.T) { + + for status, expectedCode := range inverseMapping { + originalResponse := &http.Response{} + originalResponse.StatusCode = status + + rt := EnableHTTPStatusErrors(dummyTransport{ + response: originalResponse, + }) + resp, err := rt.RoundTrip(nil) + + require.NotNil(t, simplerr.As(err), fmt.Sprintf("'%s' failed: simplerr not returned", http.StatusText(status))) + require.True(t, simplerr.HasErrorCode(err, expectedCode), fmt.Sprintf("'%s' failed: unexpected code", http.StatusText(status))) + require.Nil(t, resp) + } +} + +func TestRoundTripperNoConversionFound(t *testing.T) { + originalResponse := &http.Response{} + originalResponse.StatusCode = http.StatusOK + + rt := EnableHTTPStatusErrors(dummyTransport{ + response: originalResponse, + }) + resp, err := rt.RoundTrip(nil) + require.NoError(t, err) + require.Equal(t, originalResponse, resp) +} + +func TestRoundTripperAttrResponse(t *testing.T) { + + originalResponse := &http.Response{StatusCode: http.StatusTeapot} + + rt := EnableHTTPStatusErrors(dummyTransport{ + response: originalResponse, + }) + + resp, err := rt.RoundTrip(nil) + expectedCode := simplerr.CodeUnknown + require.True(t, simplerr.HasErrorCode(err, expectedCode), "failed to convert error code") + require.Nil(t, resp) + + gotOriginalResponse := GetHTTPResponseAttr(err) + require.Equal(t, originalResponse, gotOriginalResponse) + + require.Nil(t, GetHTTPResponseAttr(nil)) + require.Nil(t, GetHTTPResponseAttr(simplerr.New("something").Attr(attrHTTPResponse, "not an *http.Response"))) +} + +func TestRoundTripperErrorOnUnderlyingRoundTripper(t *testing.T) { + + rt := EnableHTTPStatusErrors(dummyTransport{ + err: fmt.Errorf("some error"), + }) + resp, err := rt.RoundTrip(nil) + + require.Nil(t, resp) + require.Errorf(t, err, "some error") + +}