Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: optimize recover logic #123

Merged
merged 2 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions middleware_timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ import (
"net/http"
"time"

"github.com/gin-gonic/gin"

contractshttp "github.com/goravel/framework/contracts/http"
"github.com/goravel/framework/errors"
)
Expand All @@ -24,12 +22,7 @@ func Timeout(timeout time.Duration) contractshttp.Middleware {
go func() {
defer func() {
if err := recover(); err != nil {
if globalRecoverCallback != nil {
globalRecoverCallback(ctx, err)
} else {
LogFacade.Error(err)
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, gin.H{"error": "Internal Server Error"})
}
globalRecoverCallback(ctx, err)
}

close(done)
Expand Down
73 changes: 41 additions & 32 deletions middleware_timeout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
mocksconfig "github.com/goravel/framework/mocks/config"
mockslog "github.com/goravel/framework/mocks/log"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

Expand All @@ -35,45 +36,53 @@ func TestTimeoutMiddleware(t *testing.T) {
panic(1)
})

w := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/timeout", nil)
require.NoError(t, err)
t.Run("timeout request", func(t *testing.T) {
w := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/timeout", nil)
require.NoError(t, err)

route.ServeHTTP(w, req)
assert.Equal(t, http.StatusGatewayTimeout, w.Code)
route.ServeHTTP(w, req)
assert.Equal(t, http.StatusGatewayTimeout, w.Code)
})

w = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/normal", nil)
require.NoError(t, err)
t.Run("normal request", func(t *testing.T) {
w := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/normal", nil)
require.NoError(t, err)

route.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "normal", w.Body.String())
route.ServeHTTP(w, req)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "normal", w.Body.String())
})

// Test with default recover callback
mockLog := mockslog.NewLog(t)
mockLog.EXPECT().Error(1).Once()
LogFacade = mockLog
t.Run("panic with default recover", func(t *testing.T) {
mockLog := mockslog.NewLog(t)
mockLog.EXPECT().WithContext(mock.Anything).Return(mockLog).Once()
mockLog.EXPECT().Request(mock.Anything).Return(mockLog).Once()
mockLog.EXPECT().Error(1).Once()
LogFacade = mockLog

w = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/panic", nil)
require.NoError(t, err)
w := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/panic", nil)
require.NoError(t, err)

route.ServeHTTP(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, "{\"error\":\"Internal Server Error\"}", w.Body.String())
route.ServeHTTP(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Empty(t, w.Body.String())
})

// Test with custom recover callback
globalRecover := func(ctx contractshttp.Context, err any) {
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, gin.H{"error": "Internal Panic"})
}
route.Recover(globalRecover)
t.Run("panic with custom recover", func(t *testing.T) {
globalRecover := func(ctx contractshttp.Context, err any) {
ctx.Request().AbortWithStatusJson(http.StatusInternalServerError, gin.H{"error": "Internal Panic"})
}
route.Recover(globalRecover)

w = httptest.NewRecorder()
req, err = http.NewRequest("GET", "/panic", nil)
require.NoError(t, err)
w := httptest.NewRecorder()
req, err := http.NewRequest("GET", "/panic", nil)
require.NoError(t, err)

route.ServeHTTP(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, "{\"error\":\"Internal Panic\"}", w.Body.String())
route.ServeHTTP(w, req)
assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, "{\"error\":\"Internal Panic\"}", w.Body.String())
})
}
29 changes: 16 additions & 13 deletions route.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,17 @@ import (
"github.com/gin-gonic/gin"
"github.com/gin-gonic/gin/render"
"github.com/goravel/framework/contracts/config"
httpcontract "github.com/goravel/framework/contracts/http"
contractshttp "github.com/goravel/framework/contracts/http"
"github.com/goravel/framework/contracts/route"
"github.com/goravel/framework/support"
"github.com/goravel/framework/support/color"
"github.com/savioxavier/termlink"
)

var globalRecoverCallback func(ctx httpcontract.Context, err any)
var globalRecoverCallback func(ctx contractshttp.Context, err any) = func(ctx contractshttp.Context, err any) {
LogFacade.WithContext(ctx).Request(ctx.Request()).Error(err)
ctx.Request().AbortWithStatus(http.StatusInternalServerError)
}

type Route struct {
route.Router
Expand Down Expand Up @@ -70,31 +73,31 @@ func NewRoute(config config.Config, parameters map[string]any) (*Route, error) {
config,
engine.Group("/"),
"",
[]httpcontract.Middleware{},
[]httpcontract.Middleware{ResponseMiddleware()},
[]contractshttp.Middleware{},
[]contractshttp.Middleware{ResponseMiddleware()},
),
config: config,
instance: engine,
}, nil
}

func (r *Route) Fallback(handler httpcontract.HandlerFunc) {
func (r *Route) Fallback(handler contractshttp.HandlerFunc) {
r.instance.NoRoute(handlerToGinHandler(handler))
}

func (r *Route) GlobalMiddleware(middlewares ...httpcontract.Middleware) {
func (r *Route) GlobalMiddleware(middlewares ...contractshttp.Middleware) {
timeout := time.Duration(r.config.GetInt("http.request_timeout", 3)) * time.Second
defaultMiddlewares := []httpcontract.Middleware{
defaultMiddlewares := []contractshttp.Middleware{
Cors(), Tls(), Timeout(timeout),
}
middlewares = append(defaultMiddlewares, middlewares...)
r.setMiddlewares(middlewares)
}

func (r *Route) Recover(callback func(ctx httpcontract.Context, err any)) {
func (r *Route) Recover(callback func(ctx contractshttp.Context, err any)) {
globalRecoverCallback = callback
r.setMiddlewares([]httpcontract.Middleware{
func(ctx httpcontract.Context) {
r.setMiddlewares([]contractshttp.Middleware{
func(ctx contractshttp.Context) {
defer func() {
if err := recover(); err != nil {
callback(ctx, err)
Expand Down Expand Up @@ -246,13 +249,13 @@ func (r *Route) outputRoutes() {
}
}

func (r *Route) setMiddlewares(middlewares []httpcontract.Middleware) {
func (r *Route) setMiddlewares(middlewares []contractshttp.Middleware) {
r.instance.Use(middlewaresToGinHandlers(middlewares)...)
r.Router = NewGroup(
r.config,
r.instance.Group("/"),
"",
[]httpcontract.Middleware{},
[]httpcontract.Middleware{ResponseMiddleware()},
[]contractshttp.Middleware{},
[]contractshttp.Middleware{ResponseMiddleware()},
)
}
Loading