From f6c3af3cf8dd050302864fabb80d6eb0eb1852fb Mon Sep 17 00:00:00 2001 From: zhangyi Date: Wed, 15 Feb 2023 22:09:09 +0800 Subject: [PATCH] Fix bug #47: the status code returned by gin.Context.Writer.Status() is incorrect in other custom gin middlewares. --- timeout_test.go | 32 ++++++++++++++++++++++++++++++++ writer.go | 10 ++++++++++ 2 files changed, 42 insertions(+) diff --git a/timeout_test.go b/timeout_test.go index 63aa7bd..5e5784c 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/http/httptest" + "strconv" "testing" "time" @@ -100,3 +101,34 @@ func TestPanic(t *testing.T) { assert.Equal(t, http.StatusInternalServerError, w.Code) assert.Equal(t, "", w.Body.String()) } + +func TestWriter_Status(t *testing.T) { + r := gin.New() + + r.Use(New( + WithTimeout(1*time.Second), + WithHandler(func(c *gin.Context) { + c.Next() + }), + WithResponse(testResponse), + )) + + r.Use(func(c *gin.Context) { + c.Next() + statusInMW := c.Writer.Status() + c.Request.Header.Set("X-Status-Code-MW-Set", strconv.Itoa(statusInMW)) + t.Logf("[%s] %s %s %d\n", time.Now().Format(time.RFC3339), c.Request.Method, c.Request.URL, statusInMW) + }) + + r.GET("/test", func(c *gin.Context) { + c.Writer.WriteHeader(http.StatusInternalServerError) + }) + + w := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/test", nil) + + r.ServeHTTP(w, req) + + assert.Equal(t, w.Result().StatusCode, http.StatusInternalServerError) + assert.Equal(t, strconv.Itoa(http.StatusInternalServerError), req.Header.Get("X-Status-Code-MW-Set")) +} diff --git a/writer.go b/writer.go index d0cb79b..19301b8 100644 --- a/writer.go +++ b/writer.go @@ -55,6 +55,16 @@ func (w *Writer) writeHeader(code int) { w.code = code } +// Status we must override Status func here, +// or the http status code returned by gin.Context.Writer.Status() +// will always be 200 in other custom gin middlewares. +func (w *Writer) Status() int { + if w.code == 0 || w.timeout { + return w.ResponseWriter.Status() + } + return w.code +} + // Header will get response headers func (w *Writer) Header() http.Header { return w.headers