From 5c80761dc0c154b54116f0f36245ffa8b6d0cf6d Mon Sep 17 00:00:00 2001 From: Bowen Date: Mon, 30 Dec 2024 10:41:50 +0800 Subject: [PATCH 1/2] chore: optimize recover logic --- middleware_timeout.go | 9 +-------- route.go | 29 ++++++++++++++++------------- 2 files changed, 17 insertions(+), 21 deletions(-) diff --git a/middleware_timeout.go b/middleware_timeout.go index 5733ec6..46cfe62 100644 --- a/middleware_timeout.go +++ b/middleware_timeout.go @@ -5,8 +5,6 @@ import ( "net/http" "time" - "github.com/gin-gonic/gin" - contractshttp "github.com/goravel/framework/contracts/http" "github.com/goravel/framework/errors" ) @@ -24,12 +22,7 @@ func Timeout(timeout time.Duration) contractshttp.Middleware { go func() { defer func() { if err := recover(); err != nil { - if globalRecoverCallback != nil { - globalRecoverCallback(ctx, err) - } else { - LogFacade.Error(err) - ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, gin.H{"error": "Internal Server Error"}) - } + globalRecoverCallback(ctx, err) } close(done) diff --git a/route.go b/route.go index a9b6ddc..eb9eac1 100644 --- a/route.go +++ b/route.go @@ -12,14 +12,17 @@ import ( "github.com/gin-gonic/gin" "github.com/gin-gonic/gin/render" "github.com/goravel/framework/contracts/config" - httpcontract "github.com/goravel/framework/contracts/http" + contractshttp "github.com/goravel/framework/contracts/http" "github.com/goravel/framework/contracts/route" "github.com/goravel/framework/support" "github.com/goravel/framework/support/color" "github.com/savioxavier/termlink" ) -var globalRecoverCallback func(ctx httpcontract.Context, err any) +var globalRecoverCallback func(ctx contractshttp.Context, err any) = func(ctx contractshttp.Context, err any) { + LogFacade.WithContext(ctx).Request(ctx.Request()).Error(err) + ctx.Request().AbortWithStatus(http.StatusInternalServerError) +} type Route struct { route.Router @@ -70,31 +73,31 @@ func NewRoute(config config.Config, parameters map[string]any) (*Route, error) { config, engine.Group("/"), "", - []httpcontract.Middleware{}, - []httpcontract.Middleware{ResponseMiddleware()}, + []contractshttp.Middleware{}, + []contractshttp.Middleware{ResponseMiddleware()}, ), config: config, instance: engine, }, nil } -func (r *Route) Fallback(handler httpcontract.HandlerFunc) { +func (r *Route) Fallback(handler contractshttp.HandlerFunc) { r.instance.NoRoute(handlerToGinHandler(handler)) } -func (r *Route) GlobalMiddleware(middlewares ...httpcontract.Middleware) { +func (r *Route) GlobalMiddleware(middlewares ...contractshttp.Middleware) { timeout := time.Duration(r.config.GetInt("http.request_timeout", 3)) * time.Second - defaultMiddlewares := []httpcontract.Middleware{ + defaultMiddlewares := []contractshttp.Middleware{ Cors(), Tls(), Timeout(timeout), } middlewares = append(defaultMiddlewares, middlewares...) r.setMiddlewares(middlewares) } -func (r *Route) Recover(callback func(ctx httpcontract.Context, err any)) { +func (r *Route) Recover(callback func(ctx contractshttp.Context, err any)) { globalRecoverCallback = callback - r.setMiddlewares([]httpcontract.Middleware{ - func(ctx httpcontract.Context) { + r.setMiddlewares([]contractshttp.Middleware{ + func(ctx contractshttp.Context) { defer func() { if err := recover(); err != nil { callback(ctx, err) @@ -246,13 +249,13 @@ func (r *Route) outputRoutes() { } } -func (r *Route) setMiddlewares(middlewares []httpcontract.Middleware) { +func (r *Route) setMiddlewares(middlewares []contractshttp.Middleware) { r.instance.Use(middlewaresToGinHandlers(middlewares)...) r.Router = NewGroup( r.config, r.instance.Group("/"), "", - []httpcontract.Middleware{}, - []httpcontract.Middleware{ResponseMiddleware()}, + []contractshttp.Middleware{}, + []contractshttp.Middleware{ResponseMiddleware()}, ) } From 8d656da050daebe0e18de1e87ddca3ac3436bf80 Mon Sep 17 00:00:00 2001 From: Bowen Date: Mon, 30 Dec 2024 11:05:56 +0800 Subject: [PATCH 2/2] optimize --- middleware_timeout_test.go | 73 +++++++++++++++++++++----------------- 1 file changed, 41 insertions(+), 32 deletions(-) diff --git a/middleware_timeout_test.go b/middleware_timeout_test.go index ac40596..3b6f85f 100644 --- a/middleware_timeout_test.go +++ b/middleware_timeout_test.go @@ -11,6 +11,7 @@ import ( mocksconfig "github.com/goravel/framework/mocks/config" mockslog "github.com/goravel/framework/mocks/log" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -35,45 +36,53 @@ func TestTimeoutMiddleware(t *testing.T) { panic(1) }) - w := httptest.NewRecorder() - req, err := http.NewRequest("GET", "/timeout", nil) - require.NoError(t, err) + t.Run("timeout request", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/timeout", nil) + require.NoError(t, err) - route.ServeHTTP(w, req) - assert.Equal(t, http.StatusGatewayTimeout, w.Code) + route.ServeHTTP(w, req) + assert.Equal(t, http.StatusGatewayTimeout, w.Code) + }) - w = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/normal", nil) - require.NoError(t, err) + t.Run("normal request", func(t *testing.T) { + w := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/normal", nil) + require.NoError(t, err) - route.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "normal", w.Body.String()) + route.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "normal", w.Body.String()) + }) - // Test with default recover callback - mockLog := mockslog.NewLog(t) - mockLog.EXPECT().Error(1).Once() - LogFacade = mockLog + t.Run("panic with default recover", func(t *testing.T) { + mockLog := mockslog.NewLog(t) + mockLog.EXPECT().WithContext(mock.Anything).Return(mockLog).Once() + mockLog.EXPECT().Request(mock.Anything).Return(mockLog).Once() + mockLog.EXPECT().Error(1).Once() + LogFacade = mockLog - w = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/panic", nil) - require.NoError(t, err) + w := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/panic", nil) + require.NoError(t, err) - route.ServeHTTP(w, req) - assert.Equal(t, http.StatusInternalServerError, w.Code) - assert.Equal(t, "{\"error\":\"Internal Server Error\"}", w.Body.String()) + route.ServeHTTP(w, req) + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Empty(t, w.Body.String()) + }) - // Test with custom recover callback - globalRecover := func(ctx contractshttp.Context, err any) { - ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, gin.H{"error": "Internal Panic"}) - } - route.Recover(globalRecover) + t.Run("panic with custom recover", func(t *testing.T) { + globalRecover := func(ctx contractshttp.Context, err any) { + ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, gin.H{"error": "Internal Panic"}) + } + route.Recover(globalRecover) - w = httptest.NewRecorder() - req, err = http.NewRequest("GET", "/panic", nil) - require.NoError(t, err) + w := httptest.NewRecorder() + req, err := http.NewRequest("GET", "/panic", nil) + require.NoError(t, err) - route.ServeHTTP(w, req) - assert.Equal(t, http.StatusInternalServerError, w.Code) - assert.Equal(t, "{\"error\":\"Internal Panic\"}", w.Body.String()) + route.ServeHTTP(w, req) + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Equal(t, "{\"error\":\"Internal Panic\"}", w.Body.String()) + }) }