From 67bfd7133d7e088764cdef50412e063e26f94f92 Mon Sep 17 00:00:00 2001 From: Yuki Yugui Sonoda Date: Mon, 25 Jul 2016 21:32:28 +0900 Subject: [PATCH] Avoid Internal Server Error on zero-length input for bidi streaming Fixes #195 --- examples/examplepb/flow_combination.pb.gw.go | 7 +- examples/examplepb/stream.pb.gw.go | 7 +- examples/integration_test.go | 114 ++++++++++++++++++ .../gengateway/template.go | 7 +- 4 files changed, 129 insertions(+), 6 deletions(-) diff --git a/examples/examplepb/flow_combination.pb.gw.go b/examples/examplepb/flow_combination.pb.gw.go index 2845b8dbdbc..af9da8d5f95 100644 --- a/examples/examplepb/flow_combination.pb.gw.go +++ b/examples/examplepb/flow_combination.pb.gw.go @@ -119,8 +119,11 @@ func request_FlowCombination_StreamEmptyStream_0(ctx context.Context, marshaler return nil } if err := handleSend(); err != nil { - if err := stream.CloseSend(); err != nil { - grpclog.Printf("Failed to terminate client stream: %v", err) + if cerr := stream.CloseSend(); cerr != nil { + grpclog.Printf("Failed to terminate client stream: %v", cerr) + } + if err == io.EOF { + return stream, metadata, nil } return nil, metadata, err } diff --git a/examples/examplepb/stream.pb.gw.go b/examples/examplepb/stream.pb.gw.go index ae24644cadf..f256228759a 100644 --- a/examples/examplepb/stream.pb.gw.go +++ b/examples/examplepb/stream.pb.gw.go @@ -112,8 +112,11 @@ func request_StreamService_BulkEcho_0(ctx context.Context, marshaler runtime.Mar return nil } if err := handleSend(); err != nil { - if err := stream.CloseSend(); err != nil { - grpclog.Printf("Failed to terminate client stream: %v", err) + if cerr := stream.CloseSend(); cerr != nil { + grpclog.Printf("Failed to terminate client stream: %v", cerr) + } + if err == io.EOF { + return stream, metadata, nil } return nil, metadata, err } diff --git a/examples/integration_test.go b/examples/integration_test.go index 2d345540cfe..ba011497b6a 100644 --- a/examples/integration_test.go +++ b/examples/integration_test.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "encoding/json" "fmt" "io" @@ -9,6 +10,7 @@ import ( "reflect" "strconv" "strings" + "sync" "testing" "time" @@ -152,6 +154,8 @@ func TestABE(t *testing.T) { testABELookup(t) testABELookupNotFound(t) testABEList(t) + testABEBulkEcho(t) + testABEBulkEchoZeroLength(t) testAdditionalBindings(t) } @@ -527,6 +531,116 @@ func testABEList(t *testing.T) { } } +func testABEBulkEcho(t *testing.T) { + reqr, reqw := io.Pipe() + var wg sync.WaitGroup + var want []*sub.StringMessage + wg.Add(1) + go func() { + defer wg.Done() + defer reqw.Close() + var m jsonpb.Marshaler + for i := 0; i < 1000; i++ { + msg := sub.StringMessage{Value: proto.String(fmt.Sprintf("message %d", i))} + buf, err := m.MarshalToString(&msg) + if err != nil { + t.Errorf("m.Marshal(%v) failed with %v; want success", &msg, err) + return + } + if _, err := fmt.Fprintln(reqw, buf); err != nil { + t.Errorf("fmt.Fprintln(reqw, %q) failed with %v; want success", buf, err) + return + } + want = append(want, &msg) + } + }() + + url := "http://localhost:8080/v1/example/a_bit_of_everything/echo" + req, err := http.NewRequest("POST", url, reqr) + if err != nil { + t.Errorf("http.NewRequest(%q, %q, reqr) failed with %v; want success", "POST", url, err) + return + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Transfer-Encoding", "chunked") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("http.Post(%q, %q, req) failed with %v; want success", url, "application/json", err) + return + } + defer resp.Body.Close() + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("resp.StatusCode = %d; want %d", got, want) + } + + var got []*sub.StringMessage + wg.Add(1) + go func() { + defer wg.Done() + + dec := json.NewDecoder(resp.Body) + for i := 0; ; i++ { + var item struct { + Result json.RawMessage `json:"result"` + Error map[string]interface{} `json:"error"` + } + err := dec.Decode(&item) + if err == io.EOF { + break + } + if err != nil { + t.Errorf("dec.Decode(&item) failed with %v; want success; i = %d", err, i) + } + if len(item.Error) != 0 { + t.Errorf("item.Error = %#v; want empty; i = %d", item.Error, i) + continue + } + var msg sub.StringMessage + if err := jsonpb.UnmarshalString(string(item.Result), &msg); err != nil { + t.Errorf("jsonpb.UnmarshalString(%q, &msg) failed with %v; want success", item.Result, err) + } + got = append(got, &msg) + } + }() + + wg.Wait() + if !reflect.DeepEqual(got, want) { + t.Errorf("got = %v; want %v", got, want) + } +} + +func testABEBulkEchoZeroLength(t *testing.T) { + url := "http://localhost:8080/v1/example/a_bit_of_everything/echo" + req, err := http.NewRequest("POST", url, bytes.NewReader(nil)) + if err != nil { + t.Errorf("http.NewRequest(%q, %q, bytes.NewReader(nil)) failed with %v; want success", "POST", url, err) + return + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Transfer-Encoding", "chunked") + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Errorf("http.Post(%q, %q, req) failed with %v; want success", url, "application/json", err) + return + } + defer resp.Body.Close() + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("resp.StatusCode = %d; want %d", got, want) + } + + dec := json.NewDecoder(resp.Body) + var item struct { + Result json.RawMessage `json:"result"` + Error map[string]interface{} `json:"error"` + } + if err := dec.Decode(&item); err == nil { + t.Errorf("dec.Decode(&item) succeeded; want io.EOF; item = %#v", item) + } else if err != io.EOF { + t.Errorf("dec.Decode(&item) failed with %v; want success", err) + return + } +} + func testAdditionalBindings(t *testing.T) { for i, f := range []func() *http.Response{ func() *http.Response { diff --git a/protoc-gen-grpc-gateway/gengateway/template.go b/protoc-gen-grpc-gateway/gengateway/template.go index a5781dd7689..3c3da539b95 100644 --- a/protoc-gen-grpc-gateway/gengateway/template.go +++ b/protoc-gen-grpc-gateway/gengateway/template.go @@ -262,8 +262,11 @@ var ( return nil } if err := handleSend(); err != nil { - if err := stream.CloseSend(); err != nil { - grpclog.Printf("Failed to terminate client stream: %v", err) + if cerr := stream.CloseSend(); cerr != nil { + grpclog.Printf("Failed to terminate client stream: %v", cerr) + } + if err == io.EOF { + return stream, metadata, nil } return nil, metadata, err }