Skip to content

Commit

Permalink
feat: add Status method to Writer struct and update test (#63)
Browse files Browse the repository at this point in the history
- Add a `Status` method to the `Writer` struct
- Modify the `TestWriter_Status` test function to include the new `Status` method

fixed by @zhyee

fixed #52
fixed #51

Signed-off-by: Bo-Yi Wu <appleboy.tw@gmail.com>
  • Loading branch information
appleboy authored Nov 25, 2023
1 parent f2805fd commit 7452411
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 0 deletions.
10 changes: 10 additions & 0 deletions writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ func (w *Writer) FreeBuffer() {
w.body = nil
}

// Status we must override Status func here,
// or the http status code returned by gin.Context.Writer.Status()
// will always be 200 in other custom gin middlewares.
func (w *Writer) Status() int {
if w.code == 0 || w.timeout {
return w.ResponseWriter.Status()
}
return w.code
}

func checkWriteHeaderCode(code int) {
if code < 100 || code > 999 {
panic(fmt.Sprintf("invalid http status code: %d", code))
Expand Down
36 changes: 36 additions & 0 deletions writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,13 @@ package timeout

import (
"fmt"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"

"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
)

Expand All @@ -21,3 +26,34 @@ func TestWriteHeader(t *testing.T) {
writer.WriteHeader(code2)
})
}

func TestWriter_Status(t *testing.T) {
r := gin.New()

r.Use(New(
WithTimeout(1*time.Second),
WithHandler(func(c *gin.Context) {
c.Next()
}),
WithResponse(testResponse),
))

r.Use(func(c *gin.Context) {
c.Next()
statusInMW := c.Writer.Status()
c.Request.Header.Set("X-Status-Code-MW-Set", strconv.Itoa(statusInMW))
t.Logf("[%s] %s %s %d\n", time.Now().Format(time.RFC3339), c.Request.Method, c.Request.URL, statusInMW)
})

r.GET("/test", func(c *gin.Context) {
c.Writer.WriteHeader(http.StatusInternalServerError)
})

w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)

r.ServeHTTP(w, req)

assert.Equal(t, http.StatusInternalServerError, w.Code)
assert.Equal(t, strconv.Itoa(http.StatusInternalServerError), req.Header.Get("X-Status-Code-MW-Set"))
}

0 comments on commit 7452411

Please sign in to comment.