Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add WithForwardResponseRewriter to allow easier/more useful response control #4622

Merged
merged 1 commit into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 14 additions & 17 deletions docs/docs/mapping/customizing_your_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -324,15 +324,15 @@ First, set up the gRPC-Gateway with the custom options:

```go
mux := runtime.NewServeMux(
runtime.WithMarshalerOption(runtime.MIMEWildcard, &ResponseWrapper{}),
runtime.WithForwardResponseOption(forwardResponse),
runtime.WithForwardResponseOption(setStatus),
runtime.WithForwardResponseRewriter(responseEnvelope),
)
```

Define the `forwardResponse` function to handle specific response types:
Define the `setStatus` function to handle specific response types:

```go
func forwardResponse(ctx context.Context, w http.ResponseWriter, m protoreflect.ProtoMessage) error {
func setStatus(ctx context.Context, w http.ResponseWriter, m protoreflect.ProtoMessage) error {
switch v := m.(type) {
case *pb.CreateUserResponse:
w.WriteHeader(http.StatusCreated)
Expand All @@ -342,32 +342,29 @@ func forwardResponse(ctx context.Context, w http.ResponseWriter, m protoreflect.
}
```

Create a custom marshaler to format the response data which utilizes the `JSONPb` marshaler as a fallback:
Define the `responseEnvelope` function to rewrite the response to a different type/shape:

```go
type ResponseWrapper struct {
runtime.JSONPb
}

func (c *ResponseWrapper) Marshal(data any) ([]byte, error) {
resp := data
func responseEnvelope(_ context.Context, response proto.Message) (interface{}, error) {
switch v := data.(type) {
case *pb.CreateUserResponse:
// wrap the response in a custom structure
resp = map[string]any{
return map[string]any{
"success": true,
"data": data,
}
}, nil
}
// otherwise, use the default JSON marshaller
return c.JSONPb.Marshal(resp)
return response, nil
}
```

In this setup:

- The `forwardResponse` function intercepts the response and formats it as needed.
- The `CustomPB` marshaller ensures that specific types of responses are wrapped in a custom structure before being sent to the client.
- The `setStatus` function intercepts the response and uses its type to send `201 Created` only when it sees `*pb.CreateUserResponse`.
- The `responseEnvelope` function ensures that specific types of responses are wrapped in a custom structure before being sent to the client.

❗ **NOTE:** Using `WithForwardResponseRewriter` is partially incompatible with OpenAPI annotations. Because response
rewriting happens at runtime, it is not possible to represent that in `protoc-gen-openapiv2` output.

## Error handler

Expand Down
32 changes: 14 additions & 18 deletions examples/internal/proto/examplepb/response_body_service.pb.gw.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 4 additions & 5 deletions protoc-gen-grpc-gateway/internal/gengateway/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Server(ctx context.Context,
}

{{ if $b.ResponseBody }}
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp}, mux.GetForwardResponseOptions()...)
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp.(*{{$m.ResponseType.GoType $m.Service.File.GoPkg.Path}})}, mux.GetForwardResponseOptions()...)
{{ else }}
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
{{end}}
Expand Down Expand Up @@ -744,7 +744,7 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context,
{{end}}
{{else}}
{{ if $b.ResponseBody }}
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp}, mux.GetForwardResponseOptions()...)
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}{resp.(*{{$m.ResponseType.GoType $m.Service.File.GoPkg.Path}})}, mux.GetForwardResponseOptions()...)
{{ else }}
forward_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...)
{{end}}
Expand All @@ -759,12 +759,11 @@ func Register{{$svc.GetName}}{{$.RegisterFuncSuffix}}Client(ctx context.Context,
{{range $b := $m.Bindings}}
{{if $b.ResponseBody}}
type response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}} struct {
proto.Message
*{{$m.ResponseType.GoType $m.Service.File.GoPkg.Path}}
}

func (m response_{{$svc.GetName}}_{{$m.GetName}}_{{$b.Index}}) XXX_ResponseBody() interface{} {
response := m.Message.(*{{$m.ResponseType.GoType $m.Service.File.GoPkg.Path}})
return {{$b.ResponseBody.AssignableExpr "response" $m.Service.File.GoPkg.Path}}
return {{$b.ResponseBody.AssignableExpr "m" $m.Service.File.GoPkg.Path}}
}
{{end}}
{{end}}
Expand Down
16 changes: 13 additions & 3 deletions runtime/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,26 +93,36 @@ func HTTPError(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.R
func DefaultHTTPErrorHandler(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, r *http.Request, err error) {
// return Internal when Marshal failed
const fallback = `{"code": 13, "message": "failed to marshal error message"}`
const fallbackRewriter = `{"code": 13, "message": "failed to rewrite error message"}`

var customStatus *HTTPStatusError
if errors.As(err, &customStatus) {
err = customStatus.Err
}

s := status.Convert(err)
pb := s.Proto()

w.Header().Del("Trailer")
w.Header().Del("Transfer-Encoding")

contentType := marshaler.ContentType(pb)
respRw, err := mux.forwardResponseRewriter(ctx, s.Proto())
if err != nil {
grpclog.Errorf("Failed to rewrite error message %q: %v", s, err)
w.WriteHeader(http.StatusInternalServerError)
if _, err := io.WriteString(w, fallbackRewriter); err != nil {
grpclog.Errorf("Failed to write response: %v", err)
}
return
}

contentType := marshaler.ContentType(respRw)
w.Header().Set("Content-Type", contentType)

if s.Code() == codes.Unauthenticated {
w.Header().Set("WWW-Authenticate", s.Message())
}

buf, merr := marshaler.Marshal(pb)
buf, merr := marshaler.Marshal(respRw)
if merr != nil {
grpclog.Errorf("Failed to marshal error message %q: %v", s, merr)
w.WriteHeader(http.StatusInternalServerError)
Expand Down
47 changes: 36 additions & 11 deletions runtime/errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
statuspb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)

func TestDefaultHTTPError(t *testing.T) {
Expand All @@ -24,12 +25,14 @@ func TestDefaultHTTPError(t *testing.T) {
)

for i, spec := range []struct {
err error
status int
msg string
marshaler runtime.Marshaler
contentType string
details string
err error
status int
msg string
marshaler runtime.Marshaler
contentType string
details string
fordwardRespRewriter runtime.ForwardResponseRewriter
extractMessage func(*testing.T)
}{
{
err: errors.New("example error"),
Expand Down Expand Up @@ -70,23 +73,45 @@ func TestDefaultHTTPError(t *testing.T) {
contentType: "application/json",
msg: "Method Not Allowed",
},
{
err: status.Error(codes.InvalidArgument, "example error"),
status: http.StatusBadRequest,
marshaler: &runtime.JSONPb{},
contentType: "application/json",
msg: "bad request: example error",
fordwardRespRewriter: func(ctx context.Context, response proto.Message) (any, error) {
if s, ok := response.(*statuspb.Status); ok && strings.HasPrefix(s.Message, "example") {
return &statuspb.Status{
Code: s.Code,
Message: "bad request: " + s.Message,
Details: s.Details,
}, nil
}
return response, nil
},
},
} {
t.Run(strconv.Itoa(i), func(t *testing.T) {
w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(ctx, "", "", nil) // Pass in an empty request to match the signature
mux := runtime.NewServeMux()
marshaler := &runtime.JSONPb{}
runtime.HTTPError(ctx, mux, marshaler, w, req, spec.err)

if got, want := w.Header().Get("Content-Type"), "application/json"; got != want {
opts := []runtime.ServeMuxOption{}
if spec.fordwardRespRewriter != nil {
opts = append(opts, runtime.WithForwardResponseRewriter(spec.fordwardRespRewriter))
}
mux := runtime.NewServeMux(opts...)

runtime.HTTPError(ctx, mux, spec.marshaler, w, req, spec.err)

if got, want := w.Header().Get("Content-Type"), spec.contentType; got != want {
t.Errorf(`w.Header().Get("Content-Type") = %q; want %q; on spec.err=%v`, got, want, spec.err)
}
if got, want := w.Code, spec.status; got != want {
t.Errorf("w.Code = %d; want %d", got, want)
}

var st statuspb.Status
if err := marshaler.Unmarshal(w.Body.Bytes(), &st); err != nil {
if err := spec.marshaler.Unmarshal(w.Body.Bytes(), &st); err != nil {
t.Errorf("marshaler.Unmarshal(%q, &body) failed with %v; want success", w.Body.Bytes(), err)
return
}
Expand Down
28 changes: 20 additions & 8 deletions runtime/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,20 +56,27 @@ func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshal
return
}

respRw, err := mux.forwardResponseRewriter(ctx, resp)
if err != nil {
grpclog.Errorf("Rewrite error: %v", err)
handleForwardResponseStreamError(ctx, wroteHeader, marshaler, w, req, mux, err, delimiter)
return
}

if !wroteHeader {
w.Header().Set("Content-Type", marshaler.ContentType(resp))
w.Header().Set("Content-Type", marshaler.ContentType(respRw))
}

var buf []byte
httpBody, isHTTPBody := resp.(*httpbody.HttpBody)
httpBody, isHTTPBody := respRw.(*httpbody.HttpBody)
switch {
case resp == nil:
case respRw == nil:
buf, err = marshaler.Marshal(errorChunk(status.New(codes.Internal, "empty response")))
case isHTTPBody:
buf = httpBody.GetData()
default:
result := map[string]interface{}{"result": resp}
if rb, ok := resp.(responseBody); ok {
result := map[string]interface{}{"result": respRw}
if rb, ok := respRw.(responseBody); ok {
result["result"] = rb.XXX_ResponseBody()
}

Expand Down Expand Up @@ -165,12 +172,17 @@ func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marsha
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
respRw, err := mux.forwardResponseRewriter(ctx, resp)
if err != nil {
grpclog.Errorf("Rewrite error: %v", err)
HTTPError(ctx, mux, marshaler, w, req, err)
return
}
var buf []byte
var err error
if rb, ok := resp.(responseBody); ok {
if rb, ok := respRw.(responseBody); ok {
buf, err = marshaler.Marshal(rb.XXX_ResponseBody())
} else {
buf, err = marshaler.Marshal(resp)
buf, err = marshaler.Marshal(respRw)
}
if err != nil {
grpclog.Errorf("Marshal error: %v", err)
Expand Down
Loading
Loading