Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(transport/http): remove wrapper.w and add Unwrap method for acce… #3189

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 11 additions & 16 deletions transport/http/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,24 @@ type Context interface {
Reset(http.ResponseWriter, *http.Request)
}

type responseWriter struct {
type wrappedWriter struct {
code int
w http.ResponseWriter
http.ResponseWriter
}

func (w *responseWriter) reset(res http.ResponseWriter) {
w.w = res
w.code = http.StatusOK
}
func (w *responseWriter) Header() http.Header { return w.w.Header() }
func (w *responseWriter) WriteHeader(statusCode int) { w.code = statusCode }
func (w *responseWriter) Write(data []byte) (int, error) {
w.w.WriteHeader(w.code)
return w.w.Write(data)
func (w *wrappedWriter) WriteHeader(statusCode int) { w.code = statusCode }
func (w *wrappedWriter) Write(data []byte) (int, error) {
w.ResponseWriter.WriteHeader(w.code)
return w.ResponseWriter.Write(data)
}

// Unwrap is a escape hatch for accessing wrapped http.ResponseWriter.
func (w *wrappedWriter) Unwrap() http.ResponseWriter { return w.ResponseWriter }

type wrapper struct {
router *Router
req *http.Request
res http.ResponseWriter
w responseWriter
}

func (c *wrapper) Header() http.Header {
Expand Down Expand Up @@ -104,12 +101,11 @@ func (c *wrapper) Returns(v interface{}, err error) error {
if err != nil {
return err
}
return c.router.srv.enc(&c.w, c.req, v)
return c.router.srv.enc(c.res, c.req, v)
}

func (c *wrapper) Result(code int, v interface{}) error {
c.w.WriteHeader(code)
return c.router.srv.enc(&c.w, c.req, v)
return c.router.srv.enc(&wrappedWriter{code, c.res}, c.req, v)
}

func (c *wrapper) JSON(code int, v interface{}) error {
Expand Down Expand Up @@ -152,7 +148,6 @@ func (c *wrapper) Stream(code int, contentType string, rd io.Reader) error {
}

func (c *wrapper) Reset(res http.ResponseWriter, req *http.Request) {
c.w.reset(res)
c.res = res
c.req = req
}
Expand Down
87 changes: 76 additions & 11 deletions transport/http/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"errors"
"io"
"net/http"
"net/http/httptest"
"net/url"
Expand All @@ -19,7 +20,6 @@ func TestContextHeader(t *testing.T) {
router: testRouter,
req: &http.Request{Header: map[string][]string{"name": {"kratos"}}},
res: nil,
w: responseWriter{},
}
h := w.Header()
if !reflect.DeepEqual(h, http.Header{"name": {"kratos"}}) {
Expand All @@ -32,7 +32,6 @@ func TestContextForm(t *testing.T) {
router: testRouter,
req: &http.Request{Header: map[string][]string{"name": {"kratos"}}, Method: http.MethodPost},
res: nil,
w: responseWriter{},
}
form := w.Form()
if !reflect.DeepEqual(form, url.Values{}) {
Expand All @@ -43,7 +42,6 @@ func TestContextForm(t *testing.T) {
router: testRouter,
req: &http.Request{Form: map[string][]string{"name": {"kratos"}}},
res: nil,
w: responseWriter{},
}
form = w.Form()
if !reflect.DeepEqual(form, url.Values{"name": {"kratos"}}) {
Expand All @@ -56,7 +54,6 @@ func TestContextQuery(t *testing.T) {
router: testRouter,
req: &http.Request{URL: &url.URL{Scheme: "https", Host: "github.com", Path: "go-kratos/kratos", RawQuery: "page=1"}, Method: http.MethodPost},
res: nil,
w: responseWriter{},
}
q := w.Query()
if !reflect.DeepEqual(q, url.Values{"page": {"1"}}) {
Expand All @@ -70,7 +67,6 @@ func TestContextRequest(t *testing.T) {
router: testRouter,
req: req,
res: nil,
w: responseWriter{},
}
res := w.Request()
if !reflect.DeepEqual(res, req) {
Expand All @@ -84,7 +80,6 @@ func TestContextResponse(t *testing.T) {
router: &Router{srv: &Server{enc: DefaultResponseEncoder}},
req: &http.Request{Method: http.MethodPost},
res: res,
w: responseWriter{200, res},
}
if !reflect.DeepEqual(w.Response(), res) {
t.Errorf("expected %v, got %v", res, w.Response())
Expand All @@ -100,12 +95,86 @@ func TestContextResponse(t *testing.T) {
}
}

func TestContextResult(t *testing.T) {
testCases := []struct {
name string
enc EncodeResponseFunc
code int
header string
}{
{
name: "normal",
enc: func(rw http.ResponseWriter, r *http.Request, v interface{}) error {
rw.Header().Set("X-Foo", "foo")
_, err := rw.Write([]byte(v.(string)))
return err
},
code: 400,
header: "foo",
},
{
name: "writeHeader",
enc: func(rw http.ResponseWriter, r *http.Request, v interface{}) error {
rw.Header().Set("X-Foo", "foo")
rw.WriteHeader(500)
_, err := rw.Write([]byte(v.(string)))
return err
},
code: 500,
header: "foo",
},
{
name: "unwrap",
enc: func(rw http.ResponseWriter, r *http.Request, v interface{}) error {
u, ok := rw.(interface{ Unwrap() http.ResponseWriter })
if !ok {
t.Fatal("can not wrap http.ResponseWriter")
}

w := u.Unwrap()
w.Header().Set("X-Foo", "foo")
w.WriteHeader(500)

_, err := w.Write([]byte(v.(string)))
return err
},
code: 500,
header: "foo",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
res := httptest.NewRecorder()
w := wrapper{
router: &Router{srv: &Server{enc: tc.enc}},
req: nil,
res: res,
}
err := w.Result(400, "body")
if err != nil {
t.Fatalf("expected %v, got %v", nil, err)
}

resp := res.Result()
defer resp.Body.Close()
if resp.StatusCode != tc.code {
t.Fatalf("expected %d, got %d", tc.code, resp.StatusCode)
}
if s := resp.Header.Get("X-Foo"); s != tc.header {
t.Fatalf("expected %q, got %q", tc.header, s)
}
if bs, _ := io.ReadAll(res.Body); string(bs) != "body" {
t.Fatalf("expected %s, got: %s", "body", bs)
}
})
}
}

func TestContextBindQuery(t *testing.T) {
w := wrapper{
router: testRouter,
req: &http.Request{URL: &url.URL{Scheme: "https", Host: "go-kratos-dev", RawQuery: "page=2"}},
res: nil,
w: responseWriter{},
}
type BindQuery struct {
Page int `json:"page"`
Expand All @@ -125,7 +194,6 @@ func TestContextBindForm(t *testing.T) {
router: testRouter,
req: &http.Request{URL: &url.URL{Scheme: "https", Host: "go-kratos-dev"}, Form: map[string][]string{"page": {"2"}}},
res: nil,
w: responseWriter{},
}
type BindForm struct {
Page int `json:"page"`
Expand All @@ -146,7 +214,6 @@ func TestContextResponseReturn(t *testing.T) {
router: testRouter,
req: nil,
res: writer,
w: responseWriter{},
}
err := w.JSON(200, "success")
if err != nil {
Expand Down Expand Up @@ -179,7 +246,6 @@ func TestContextCtx(t *testing.T) {
router: testRouter,
req: req,
res: nil,
w: responseWriter{},
}
_, ok := w.Deadline()
if !ok {
Expand All @@ -202,7 +268,6 @@ func TestContextCtx(t *testing.T) {
router: &Router{srv: &Server{enc: DefaultResponseEncoder}},
req: nil,
res: nil,
w: responseWriter{},
}
_, ok = w.Deadline()
if ok {
Expand Down
Loading