From c6d69ab10cc3f2f039d216ac5cf6d3fce23933fe Mon Sep 17 00:00:00 2001 From: Calvin Lobo Date: Fri, 23 Dec 2022 09:19:21 -0500 Subject: [PATCH] Added test for reverse middleware adapter calling error handler if the pre-call middleware fails --- ecosystem/http/middleware.go | 10 +++-- ecosystem/http/middleware_test.go | 66 ++++++++++++++++++++++++++++++- 2 files changed, 71 insertions(+), 5 deletions(-) diff --git a/ecosystem/http/middleware.go b/ecosystem/http/middleware.go index 561cf74..42877fd 100644 --- a/ecosystem/http/middleware.go +++ b/ecosystem/http/middleware.go @@ -66,9 +66,13 @@ type middlewareReverseAdapter struct { h HandlerFunc } -// ServeHTTP satisfies the Handler interface but disregards returning any errors because it uses the http.Handler -func (a middlewareReverseAdapter) ServeHTTP(writer http.ResponseWriter, request *http.Request) { - _ = a.h.ServeHTTP(writer, request) +// ServeHTTP satisfies the Handler interface, calls the error handler but disregards returning any errors because +// it needs to satisfy the http.Handler interface +func (a middlewareReverseAdapter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + err := a.h.ServeHTTP(w, r) + if err != nil { + DefaultErrorHandler(w, r, err) + } } // MiddlewareReverseAdapter is an adapter for turning simplehttp compatible middleware to standard library middleware diff --git a/ecosystem/http/middleware_test.go b/ecosystem/http/middleware_test.go index 3a8a920..ddf7d05 100644 --- a/ecosystem/http/middleware_test.go +++ b/ecosystem/http/middleware_test.go @@ -206,8 +206,8 @@ func TestMiddlewareAdapter(t *testing.T) { stdlibHandler := http.HandlerFunc(NewHandlerAdapter(HandlerFunc(ep)).ServeHTTP) // Apply the middleware - ep := stdlibMiddlewarePre(stdlibHandler) - ep = stdlibMiddlewarePost(ep) + ep := stdlibMiddlewarePost(stdlibHandler) + ep = stdlibMiddlewarePre(ep) rec := httptest.NewRecorder() ep.ServeHTTP(rec, req) @@ -241,3 +241,65 @@ func TestMiddlewareAdapter(t *testing.T) { }) } + +func TestMiddlewareReverseAdapter(t *testing.T) { + + ep := func(writer http.ResponseWriter, request *http.Request) error { + request.Header.Add("call", timestamp()) + + n, err := writer.Write([]byte("done")) + require.NoError(t, err) + require.Greater(t, n, 0) + + return nil + } + + req, err := http.NewRequest("GET", "url", nil) + require.NoError(t, err) + + t.Run("pre error", func(t *testing.T) { + // Convert the simplehttp middlewares to a standard lib middleware + stdlibMiddlewarePre := MiddlewareReverseAdapter(ErrorCausingMiddleware(true)) + stdlibMiddlewarePost := MiddlewareReverseAdapter(PostCallMiddleware) + + // Convert the simplehttp handler func to a standard lib handler + stdlibHandler := http.HandlerFunc(NewHandlerAdapter(HandlerFunc(ep)).ServeHTTP) + + // Apply the middleware + ep := stdlibMiddlewarePost(stdlibHandler) + ep = stdlibMiddlewarePre(ep) + + rec := httptest.NewRecorder() + ep.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Result().StatusCode) + // There should be no response written because the endpoint should not have been called + _, err = rec.Body.ReadByte() + require.True(t, err == io.EOF) + + // The post call middleware should not be called + h := rec.Header().Get("post-call-middleware") + require.Empty(t, h, "post middleware should not have been called") + }) + + t.Run("post error", func(t *testing.T) { + + // Convert the simplehttp middlewares to a standard lib middleware + stdlibMiddlewarePre := MiddlewareReverseAdapter(ErrorCausingMiddleware(false)) + + // Convert the simplehttp handler func to a standard lib handler + stdlibHandler := http.HandlerFunc(NewHandlerAdapter(HandlerFunc(ep)).ServeHTTP) + + // Apply the middleware + ep := stdlibMiddlewarePre(stdlibHandler) + + rec := httptest.NewRecorder() + ep.ServeHTTP(rec, req) + + // Because the handler func has been called, the status gets written and we can only + // write the status once, so we cannot change the status from 200 even if we wanted. + require.Equal(t, http.StatusOK, rec.Result().StatusCode) + + }) + +}