diff --git a/ecosystem/http/translate_error_code.go b/ecosystem/http/translate_error_code.go index 0d71215..638e58b 100644 --- a/ecosystem/http/translate_error_code.go +++ b/ecosystem/http/translate_error_code.go @@ -2,6 +2,7 @@ package simplehttp import ( "net/http" + "sync" "github.com/lobocv/simplerr" ) @@ -10,9 +11,13 @@ import ( type HTTPStatus = int var mapping map[simplerr.Code]HTTPStatus +var inverseMapping map[HTTPStatus]simplerr.Code + var simplerrCodes []simplerr.Code var defaultErrorStatus = http.StatusInternalServerError +var lock = sync.Mutex{} + // DefaultMapping returns the default mapping of SimpleError codes to HTTP status codes func DefaultMapping() map[simplerr.Code]HTTPStatus { var m = map[simplerr.Code]HTTPStatus{ @@ -22,15 +27,37 @@ func DefaultMapping() map[simplerr.Code]HTTPStatus { simplerr.CodePermissionDenied: http.StatusForbidden, simplerr.CodeUnauthenticated: http.StatusUnauthorized, simplerr.CodeNotImplemented: http.StatusNotImplemented, + simplerr.CodeMalformedRequest: http.StatusBadRequest, simplerr.CodeInvalidArgument: http.StatusBadRequest, + simplerr.CodeMissingParameter: http.StatusBadRequest, simplerr.CodeResourceExhausted: http.StatusTooManyRequests, } return m } +// DefaultInverseMapping returns the default mapping of HTTP status codes to SimpleError code +func DefaultInverseMapping() map[HTTPStatus]simplerr.Code { + var m = map[HTTPStatus]simplerr.Code{ + http.StatusInternalServerError: simplerr.CodeUnknown, + http.StatusNotFound: simplerr.CodeNotFound, + http.StatusRequestTimeout: simplerr.CodeDeadlineExceeded, + http.StatusForbidden: simplerr.CodePermissionDenied, + http.StatusUnauthorized: simplerr.CodeUnauthenticated, + http.StatusNotImplemented: simplerr.CodeNotImplemented, + http.StatusBadRequest: simplerr.CodeMalformedRequest, + http.StatusServiceUnavailable: simplerr.CodeUnavailable, + http.StatusMethodNotAllowed: simplerr.CodeMalformedRequest, + http.StatusTooManyRequests: simplerr.CodeResourceExhausted, + } + return m +} + // SetMapping sets the mapping from simplerr.Code to HTTP status code func SetMapping(m map[simplerr.Code]HTTPStatus) { + lock.Lock() + defer lock.Unlock() mapping = m + simplerrCodes = []simplerr.Code{} // Get a list of simplerr codes to search for in the error chain for c := range m { @@ -42,6 +69,13 @@ func SetMapping(m map[simplerr.Code]HTTPStatus) { } } +// SetInverseMapping sets the mapping from HTTP status code to simplerr.Code +func SetInverseMapping(m map[HTTPStatus]simplerr.Code) { + lock.Lock() + defer lock.Unlock() + inverseMapping = m +} + // SetDefaultErrorStatus changes the default HTTP status code for when a translation could not be found. // The default status code is 500. func SetDefaultErrorStatus(code int) { @@ -50,6 +84,7 @@ func SetDefaultErrorStatus(code int) { func init() { SetMapping(DefaultMapping()) + SetInverseMapping(DefaultInverseMapping()) } // SetStatus sets the http.Response status from the error code in the provided error. @@ -94,3 +129,12 @@ func GetStatus(err error) (status HTTPStatus, found bool) { return httpCode, true } + +// GetCode gets the simplerror Code that corresponds to the HTTPStatus. It returns CodeUknown if it cannot map the status. +func GetCode(status HTTPStatus) (code simplerr.Code, found bool) { + code, ok := inverseMapping[status] + if !ok { + return simplerr.CodeUnknown, false + } + return code, true +} diff --git a/ecosystem/http/translate_error_code_test.go b/ecosystem/http/translate_error_code_test.go index 94c4ffa..7404866 100644 --- a/ecosystem/http/translate_error_code_test.go +++ b/ecosystem/http/translate_error_code_test.go @@ -2,16 +2,34 @@ package simplehttp import ( "fmt" + "github.com/stretchr/testify/suite" "net/http" "net/http/httptest" "testing" - "github.com/stretchr/testify/require" - "github.com/lobocv/simplerr" ) -func TestTranslateErrorCode(t *testing.T) { +type TestSuite struct { + suite.Suite +} + +func TestHTTP(t *testing.T) { + s := new(TestSuite) + + // Change the default mappings to test that they apply + m := DefaultMapping() + m[simplerr.CodeCanceled] = http.StatusRequestTimeout + SetMapping(m) + SetDefaultErrorStatus(http.StatusInternalServerError) + + invM := DefaultInverseMapping() + invM[http.StatusRequestTimeout] = simplerr.CodeCanceled + SetInverseMapping(invM) + suite.Run(t, s) +} + +func (s *TestSuite) TestTranslateErrorCode() { testCases := []struct { err error @@ -23,30 +41,52 @@ func TestTranslateErrorCode(t *testing.T) { {simplerr.New("something").Code(simplerr.CodePermissionDenied), http.StatusForbidden, true}, {simplerr.New("something").Code(simplerr.CodeCanceled), http.StatusRequestTimeout, true}, {simplerr.New("something").Code(simplerr.CodeConstraintViolated), http.StatusInternalServerError, false}, + {simplerr.New("something").Code(simplerr.CodeMalformedRequest), http.StatusBadRequest, true}, + {simplerr.New("something").Code(simplerr.CodeMissingParameter), http.StatusBadRequest, true}, {fmt.Errorf("wrapped: %w", simplerr.New("something").Code(simplerr.CodeUnauthenticated)), http.StatusUnauthorized, true}, {fmt.Errorf("opaque: %s", simplerr.New("something").Code(simplerr.CodeUnauthenticated)), http.StatusInternalServerError, false}, {simplerr.Wrap(simplerr.New("something").Code(simplerr.CodePermissionDenied)), http.StatusForbidden, true}, {nil, 200, false}, // default code for httptest.ResponseRecorder is 200 } - // Alter the default mapping - m := DefaultMapping() - m[simplerr.CodeCanceled] = http.StatusRequestTimeout - SetMapping(m) - SetDefaultErrorStatus(http.StatusInternalServerError) - for ii, tc := range testCases { r := httptest.NewRecorder() SetStatus(r, tc.err) // Check that GetStatus returns a status when there is a mapping gotStatus, mappingFound := GetStatus(tc.err) - require.Equal(t, tc.expectMappingFound, mappingFound, fmt.Sprintf("test case %d failed", ii)) + s.Equal(tc.expectMappingFound, mappingFound, fmt.Sprintf("test case %d failed", ii)) + if mappingFound { + s.Equal(tc.expected, gotStatus, fmt.Sprintf("test case %d failed", ii)) + } + + s.Equal(tc.expected, r.Code, fmt.Sprintf("test case %d failed", ii)) + } + +} + +func (s *TestSuite) TestTranslateStatusCode() { + + testCases := []struct { + status HTTPStatus + expected simplerr.Code + expectMappingFound bool + }{ + {http.StatusNotFound, simplerr.CodeNotFound, true}, + {http.StatusRequestTimeout, simplerr.CodeCanceled, true}, + {23587253923, simplerr.CodeUnknown, false}, + } + + for ii, tc := range testCases { + gotCode, mappingFound := GetCode(tc.status) + + // Check that GetStatus returns a status when there is a mapping + s.Equal(tc.expectMappingFound, mappingFound, fmt.Sprintf("test case %d failed", ii)) if mappingFound { - require.Equal(t, tc.expected, gotStatus, fmt.Sprintf("test case %d failed", ii)) + s.Equal(tc.expected, gotCode, fmt.Sprintf("test case %d failed", ii)) } - require.Equal(t, tc.expected, r.Code, fmt.Sprintf("test case %d failed", ii)) + s.Equal(tc.expected, gotCode, fmt.Sprintf("test case %d failed", ii)) } }