diff --git a/path.go b/path.go index d1004da..1ca5c66 100644 --- a/path.go +++ b/path.go @@ -38,6 +38,17 @@ func (p *Path) UseAfter(fns ...Middleware) *Path { return p } +// UseFinal registers the given middlewares to be executed in the order in which they are added, +// after the view or group has been executed. These middlewares will always be executed, +// even if a previous middleware or the view/group returned a response. +func (p *Path) UseFinal(fns ...FinalMiddleware) *Path { + p.middlewares.Final = append(p.middlewares.Final, fns...) + + p.router.handlePath(p) + + return p +} + // SkipMiddlewares registers the middlewares that you want to skip only when executing the view. func (p *Path) SkipMiddlewares(fns ...Middleware) *Path { p.middlewares.Skip = append(p.middlewares.Skip, fns...) diff --git a/path_test.go b/path_test.go index 3237b14..b319b00 100644 --- a/path_test.go +++ b/path_test.go @@ -96,6 +96,24 @@ func TestPath_UseAfter(t *testing.T) { assertHandle(t, p) } +func TestPath_UseFinal(t *testing.T) { + finalMiddlewareFns := []FinalMiddleware{ + func(ctx *RequestCtx) { + }, + func(ctx *RequestCtx) { + }, + } + + p := newTestPath() + p.UseFinal(finalMiddlewareFns...) + + if len(p.middlewares.Final) != len(middlewareFns) { + t.Errorf("Final middlewares are not registered") + } + + assertHandle(t, p) +} + func TestPath_SkipMiddlewares(t *testing.T) { p := newTestPath() p.SkipMiddlewares(middlewareFns...) diff --git a/router.go b/router.go index be77ba0..3fd5261 100644 --- a/router.go +++ b/router.go @@ -72,6 +72,9 @@ func (r *Router) buildMiddlewares(m Middlewares) Middlewares { m2.Skip = append(m2.Skip, m.Skip...) m2.Skip = append(m2.Skip, r.middlewares.Skip...) + m2.Final = append(m2.Final, m.Final...) + m2.Final = append(m2.Final, r.middlewares.Final...) + if r.parent != nil { return r.parent.buildMiddlewares(m2) } @@ -112,12 +115,7 @@ func (r *Router) handler(fn View, middle Middlewares) fasthttp.RequestHandler { for i := 0; i < chainLen; i++ { if err := chain[i](actx); err != nil { - statusCode := actx.Response.Header.StatusCode() - if statusCode == fasthttp.StatusOK { - statusCode = fasthttp.StatusInternalServerError - } - - r.errorView(actx, err, statusCode) + r.handleMiddlewareError(actx, err) break } else if !actx.next { @@ -127,10 +125,23 @@ func (r *Router) handler(fn View, middle Middlewares) fasthttp.RequestHandler { actx.next = false } + for _, final := range middle.Final { + final(actx) + } + ReleaseRequestCtx(actx) } } +func (r *Router) handleMiddlewareError(ctx *RequestCtx, err error) { + statusCode := ctx.Response.Header.StatusCode() + if statusCode == fasthttp.StatusOK { + statusCode = fasthttp.StatusInternalServerError + } + + r.errorView(ctx, err, statusCode) +} + func (r *Router) handlePath(p *Path) { isOPTIONS := p.method == fasthttp.MethodOptions @@ -220,6 +231,15 @@ func (r *Router) UseAfter(fns ...Middleware) *Router { return r } +// UseFinal registers the given middlewares to be executed in the order in which they are added, +// after the view or group has been executed. These middlewares will always be executed, +// even if a previous middleware or the view/group returned a response. +func (r *Router) UseFinal(fns ...FinalMiddleware) *Router { + r.middlewares.Final = append(r.middlewares.Final, fns...) + + return r +} + // SkipMiddlewares registers the middlewares that you want to skip when executing the view or group. func (r *Router) SkipMiddlewares(fns ...Middleware) *Router { r.middlewares.Skip = append(r.middlewares.Skip, fns...) diff --git a/router_test.go b/router_test.go index fb61399..ed61066 100644 --- a/router_test.go +++ b/router_test.go @@ -178,10 +178,12 @@ func TestRouter_buildMiddlewares(t *testing.T) { middleware1 := func(ctx *RequestCtx) error { return ctx.Next() } middleware2 := func(ctx *RequestCtx) error { return ctx.Next() } middleware3 := func(ctx *RequestCtx) error { return ctx.Next() } + middleware4 := func(ctx *RequestCtx) {} middle := Middlewares{ Before: []Middleware{middleware1, middleware2}, After: []Middleware{middleware3}, + Final: []FinalMiddleware{middleware4}, } m := Middlewares{ Skip: []Middleware{middleware1}, @@ -208,6 +210,10 @@ func TestRouter_buildMiddlewares(t *testing.T) { if len(result.After) != wantAfterLen { t.Errorf("Middlewares.After length == %d, want %d", len(result.After), wantAfterLen) } + + if wantFinalLen := len(middle.Final); len(result.Final) != wantFinalLen { + t.Errorf("Middlewares.Final length == %d, want %d", len(result.Final), wantFinalLen) + } } } @@ -228,6 +234,9 @@ func TestRouter_handlerExecutionChain(t *testing.T) { //nolint:funlen "viewAfter": 0, "groupAfter": 0, "globalAfter": 0, + "viewFinal": 0, + "groupFinal": 0, + "globalFinal": 0, } wantOrder := map[string]int{ @@ -237,6 +246,9 @@ func TestRouter_handlerExecutionChain(t *testing.T) { //nolint:funlen "viewAfter": 4, "groupAfter": 5, "globalAfter": 6, + "viewFinal": 7, + "groupFinal": 8, + "globalFinal": 9, } index := 0 @@ -264,6 +276,10 @@ func TestRouter_handlerExecutionChain(t *testing.T) { //nolint:funlen return ctx.Next() }) + s.UseFinal(func(ctx *RequestCtx) { + index++ + callOrder["globalFinal"] = index + }) v1 := s.NewGroupPath("/v1") v1.UseBefore(func(ctx *RequestCtx) error { @@ -278,6 +294,10 @@ func TestRouter_handlerExecutionChain(t *testing.T) { //nolint:funlen return ctx.Next() }, skipMiddlewareGroup) + v1.UseFinal(func(ctx *RequestCtx) { + index++ + callOrder["groupFinal"] = index + }) v1.SkipMiddlewares(skipMiddlewareGlobal) @@ -295,6 +315,9 @@ func TestRouter_handlerExecutionChain(t *testing.T) { //nolint:funlen callOrder["viewAfter"] = index return ctx.Next() + }).UseFinal(func(ctx *RequestCtx) { + index++ + callOrder["viewFinal"] = index }).SkipMiddlewares(skipMiddlewareGroup) ctx := new(fasthttp.RequestCtx) @@ -362,12 +385,15 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx beforeViewMiddlewares int afterViewMiddlewares int afterMiddlewares int + finalViewMiddlewares int + finalMiddlewares int } type args struct { viewFn View before []Middleware after []Middleware + final []FinalMiddleware middlewares Middlewares } @@ -398,6 +424,12 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx return ctx.Next() }, } + final := []FinalMiddleware{ + func(ctx *RequestCtx) { + handlerCounter.finalMiddlewares++ + }, + } + middlewares := Middlewares{ Before: []Middleware{ func(ctx *RequestCtx) error { @@ -413,6 +445,11 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx return ctx.Next() }, }, + Final: []FinalMiddleware{ + func(ctx *RequestCtx) { + handlerCounter.finalViewMiddlewares++ + }, + }, } tests := []struct { @@ -426,6 +463,7 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx viewFn: viewFn, before: before, after: after, + final: final, middlewares: middlewares, }, want: want{ @@ -436,6 +474,8 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx beforeViewMiddlewares: len(middlewares.Before), afterViewMiddlewares: len(middlewares.After), afterMiddlewares: len(after), + finalViewMiddlewares: len(middlewares.Final), + finalMiddlewares: len(final), }, }, }, @@ -452,6 +492,7 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx }, }, after: after, + final: final, middlewares: middlewares, }, want: want{ @@ -462,6 +503,8 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx beforeViewMiddlewares: len(middlewares.Before), afterViewMiddlewares: len(middlewares.After), afterMiddlewares: len(after), + finalViewMiddlewares: len(middlewares.Final), + finalMiddlewares: len(final), }, }, }, @@ -473,6 +516,7 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx }, before: before, after: after, + final: final, middlewares: middlewares, }, want: want{ @@ -483,6 +527,8 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx beforeViewMiddlewares: len(middlewares.Before), afterViewMiddlewares: 0, afterMiddlewares: 0, + finalViewMiddlewares: len(middlewares.Final), + finalMiddlewares: len(final), }, }, }, @@ -498,6 +544,7 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx }, }, after: after, + final: final, middlewares: middlewares, }, want: want{ @@ -508,6 +555,8 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx beforeViewMiddlewares: 0, afterViewMiddlewares: 0, afterMiddlewares: 0, + finalViewMiddlewares: len(middlewares.Final), + finalMiddlewares: len(final), }, }, }, @@ -517,6 +566,7 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx viewFn: viewFn, before: before, after: after, + final: final, middlewares: Middlewares{ Before: []Middleware{ func(ctx *RequestCtx) error { @@ -532,6 +582,11 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx return ctx.Next() }, }, + Final: []FinalMiddleware{ + func(ctx *RequestCtx) { + handlerCounter.finalViewMiddlewares++ + }, + }, }, }, want: want{ @@ -542,6 +597,8 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx beforeViewMiddlewares: 1, afterViewMiddlewares: 0, afterMiddlewares: 0, + finalViewMiddlewares: len(middlewares.Final), + finalMiddlewares: len(final), }, }, }, @@ -551,6 +608,7 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx viewFn: viewFn, before: before, after: after, + final: final, middlewares: Middlewares{ Before: []Middleware{ func(ctx *RequestCtx) error { @@ -566,6 +624,11 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx return ctx.ErrorResponse(err, fasthttp.StatusBadRequest) }, }, + Final: []FinalMiddleware{ + func(ctx *RequestCtx) { + handlerCounter.finalViewMiddlewares++ + }, + }, }, }, want: want{ @@ -576,6 +639,8 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx beforeViewMiddlewares: 1, afterViewMiddlewares: 1, afterMiddlewares: 0, + finalViewMiddlewares: len(middlewares.Final), + finalMiddlewares: len(final), }, }, }, @@ -591,6 +656,7 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx return ctx.ErrorResponse(err, fasthttp.StatusBadRequest) }, }, + final: final, middlewares: middlewares, }, want: want{ @@ -601,6 +667,8 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx beforeViewMiddlewares: len(middlewares.Before), afterViewMiddlewares: len(middlewares.After), afterMiddlewares: 1, + finalViewMiddlewares: len(middlewares.Final), + finalMiddlewares: len(final), }, }, }, @@ -616,6 +684,7 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx }, }, after: after, + final: final, middlewares: middlewares, }, want: want{ @@ -626,6 +695,8 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx beforeViewMiddlewares: 0, afterViewMiddlewares: 0, afterMiddlewares: 0, + finalViewMiddlewares: len(middlewares.Final), + finalMiddlewares: len(final), }, }, }, @@ -642,6 +713,8 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx handlerCounter.beforeViewMiddlewares = 0 handlerCounter.afterViewMiddlewares = 0 handlerCounter.afterMiddlewares = 0 + handlerCounter.finalViewMiddlewares = 0 + handlerCounter.finalMiddlewares = 0 t.Run(tt.name, func(t *testing.T) { t.Helper() @@ -656,6 +729,7 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx }) r.UseBefore(tt.args.before...) r.UseAfter(tt.args.after...) + r.UseFinal(tt.args.final...) r.Path(method, path, tt.args.viewFn).Middlewares(tt.args.middlewares) ctx := new(fasthttp.RequestCtx) @@ -690,6 +764,16 @@ func TestRouter_handler(t *testing.T) { //nolint:funlen,maintidx t.Errorf("After view call counter = %v, want %v", handlerCounter.afterViewMiddlewares, tt.want.counter.afterViewMiddlewares) } + + if handlerCounter.finalMiddlewares != tt.want.counter.finalMiddlewares { + t.Errorf("Final middlewares call counter = %v, want %v", handlerCounter.finalMiddlewares, + tt.want.counter.finalMiddlewares) + } + + if handlerCounter.finalViewMiddlewares != tt.want.counter.finalViewMiddlewares { + t.Errorf("Final view call counter = %v, want %v", handlerCounter.finalViewMiddlewares, + tt.want.counter.finalViewMiddlewares) + } }) } } diff --git a/types.go b/types.go index 28d1c5d..3d6a98a 100644 --- a/types.go +++ b/types.go @@ -546,11 +546,15 @@ type PanicView func(*RequestCtx, interface{}) // Middleware must process all incoming requests before/after defined views. type Middleware View +// FinalMiddleware must process all incoming requests after the other middlewares/view. +type FinalMiddleware func(*RequestCtx) + // Middlewares is a collection of middlewares with the order of execution and which to skip. type Middlewares struct { Before []Middleware After []Middleware Skip []Middleware + Final []FinalMiddleware } // PathRewriteFunc must return new request path based on arbitrary ctx