Skip to content

Commit

Permalink
Create a temporary context for the first middleware to run for not se…
Browse files Browse the repository at this point in the history
…nding content to client if the middleware fail
  • Loading branch information
jmaitrehenry committed Oct 2, 2021
1 parent 61f2247 commit d0ec91b
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 8 deletions.
17 changes: 15 additions & 2 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,23 @@ func Mixed() func(handler1, handler2 echo.MiddlewareFunc) echo.MiddlewareFunc {
a2 := handler2(next)

return func(c echo.Context) error {
if a1(c) == nil {
tempContext := copyContext(c)
// tempContext.Set(ContextUserInfoKey, c.Get(ContextUserInfoKey))
defer tempContext.Echo().ReleaseContext(tempContext)
if a1(tempContext) == nil {
copyResponse(c, tempContext)

return nil
}
return a2(c)

// Try the second middleware
err := a2(c)
if err != nil {
// Return the first middleware error
copyResponse(c, tempContext)
c.Response().WriteHeader(tempContext.Response().Status)
}
return err
}
}
}
Expand Down
12 changes: 6 additions & 6 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
)

func TestMixedAuthenticationSucceedsOnFirstAuth(t *testing.T) {
auth1 := succeedingMiddleware("hello")
auth1 := succeedingMiddleware(204)
auth2 := failingMiddleware(errors.New("boom"))

ctx := newContext()
Expand All @@ -22,12 +22,12 @@ func TestMixedAuthenticationSucceedsOnFirstAuth(t *testing.T) {
err := handlerFunc(ctx)

assert.NoError(t, err)
assert.Equal(t, "hello", ctx.Get("KEY"))
assert.Equal(t, 204, ctx.Response().Status)
}

func TestMixedAuthenticationSucceedsOnSecondAuth(t *testing.T) {
auth1 := failingMiddleware(errors.New("boom"))
auth2 := succeedingMiddleware("hello2")
auth2 := succeedingMiddleware(204)

ctx := newContext()
md := Mixed()(auth1, auth2)
Expand All @@ -36,7 +36,7 @@ func TestMixedAuthenticationSucceedsOnSecondAuth(t *testing.T) {
err := handlerFunc(ctx)

assert.NoError(t, err)
assert.Equal(t, "hello2", ctx.Get("KEY"))
assert.Equal(t, 204, ctx.Response().Status)
}

func TestMixedAuthenticationFailsOnBothFailedAuths(t *testing.T) {
Expand All @@ -52,10 +52,10 @@ func TestMixedAuthenticationFailsOnBothFailedAuths(t *testing.T) {
assert.Equal(t, "boom2", err.Error())
}

func succeedingMiddleware(succeedingValue string) echo.MiddlewareFunc {
func succeedingMiddleware(status int) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Set("KEY", succeedingValue)
c.NoContent(status)
return nil
}
}
Expand Down
24 changes: 24 additions & 0 deletions tempresponsewriter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package mixed

import "net/http"

// tempResponseWriter implement the http.ResponseWriter interface
// It used when you need to temporary hold the content and use it later
type tempResponseWriter struct {
header http.Header
StatusCode int
Content []byte
}

func (rw *tempResponseWriter) Header() http.Header {
return rw.header
}

func (rw *tempResponseWriter) Write(content []byte) (int, error) {
rw.Content = append(rw.Content, content...)
return len(content), nil
}

func (rw *tempResponseWriter) WriteHeader(code int) {
rw.StatusCode = code
}
41 changes: 41 additions & 0 deletions utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package mixed

import (
"net/http"

echo "github.com/labstack/echo/v4"
)

// copyContext create a new echo.Context with the same Request, Path, Params and Handler
// but with a Response that contain an in-memory ResponseWriter
func copyContext(c echo.Context) echo.Context {
cc := c.Echo().AcquireContext()
cc.SetRequest(c.Request())
cc.SetPath(c.Path())
cc.SetParamNames(c.ParamNames()...)
cc.SetParamValues(c.ParamValues()...)
cc.SetHandler(c.Handler())

rw := tempResponseWriter{
header: make(http.Header),
Content: []byte{},
}
resp := echo.NewResponse(&rw, c.Echo())
cc.SetResponse(resp)
return cc
}

// copyResponse copy c2 headers and content into c1
func copyResponse(c1, c2 echo.Context) {
for k, v := range c2.Response().Header() {
for _, vv := range v {
c1.Response().Header().Set(k, vv)
}
}

if c2.Response().Status > 0 {
c1.Response().WriteHeader(c2.Response().Status)
}

c1.Response().Write(c2.Response().Writer.(*tempResponseWriter).Content)
}

0 comments on commit d0ec91b

Please sign in to comment.