Skip to content

Commit

Permalink
Client should have a check on maximum size of received message size.
Browse files Browse the repository at this point in the history
  • Loading branch information
MakMukhi committed Mar 10, 2017
1 parent 4eaacfe commit a94b094
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 38 deletions.
3 changes: 1 addition & 2 deletions call.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ package grpc
import (
"bytes"
"io"
"math"
"time"

"golang.org/x/net/context"
Expand Down Expand Up @@ -73,7 +72,7 @@ func recvResponse(ctx context.Context, dopts dialOptions, t transport.ClientTran
}
}
for {
if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32, inPayload); err != nil {
if err = recv(p, dopts.codec, stream, dopts.dc, reply, dopts.maxMsgSize, inPayload); err != nil {
if err == io.EOF {
break
}
Expand Down
40 changes: 27 additions & 13 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ package grpc
import (
"errors"
"fmt"
"math"
"net"
"strings"
"sync"
Expand Down Expand Up @@ -87,23 +88,33 @@ var (
// dialOptions configure a Dial call. dialOptions are set by the DialOption
// values passed to Dial.
type dialOptions struct {
unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor
codec Codec
cp Compressor
dc Decompressor
bs backoffStrategy
balancer Balancer
block bool
insecure bool
timeout time.Duration
scChan <-chan ServiceConfig
copts transport.ConnectOptions
}
unaryInt UnaryClientInterceptor
streamInt StreamClientInterceptor
codec Codec
cp Compressor
dc Decompressor
bs backoffStrategy
balancer Balancer
block bool
insecure bool
timeout time.Duration
scChan <-chan ServiceConfig
copts transport.ConnectOptions
maxMsgSize int
}

const defaultClientMaxMsgSize = math.MaxInt32

// DialOption configures how we set up the connection.
type DialOption func(*dialOptions)

// WithMaxMsgSize returns a DialOption which sets the maximum message size the client can receive.
func WithMaxMsgSize(s int) DialOption {
return func(o *dialOptions) {
o.maxMsgSize = s
}
}

// WithCodec returns a DialOption which sets a codec for message marshaling and unmarshaling.
func WithCodec(c Codec) DialOption {
return func(o *dialOptions) {
Expand Down Expand Up @@ -304,6 +315,9 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *
ctx, cancel = context.WithTimeout(ctx, cc.dopts.timeout)
defer cancel()
}
if cc.dopts.maxMsgSize == 0 {
cc.dopts.maxMsgSize = defaultClientMaxMsgSize
}

defer func() {
select {
Expand Down
43 changes: 22 additions & 21 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ import (
"bytes"
"errors"
"io"
"math"
"sync"
"time"

Expand Down Expand Up @@ -208,13 +207,14 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth
break
}
cs := &clientStream{
opts: opts,
c: c,
desc: desc,
codec: cc.dopts.codec,
cp: cc.dopts.cp,
dc: cc.dopts.dc,
cancel: cancel,
opts: opts,
c: c,
desc: desc,
codec: cc.dopts.codec,
cp: cc.dopts.cp,
dc: cc.dopts.dc,
maxMsgSize: cc.dopts.maxMsgSize,
cancel: cancel,

put: put,
t: t,
Expand Down Expand Up @@ -259,17 +259,18 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth

// clientStream implements a client side Stream.
type clientStream struct {
opts []CallOption
c callInfo
t transport.ClientTransport
s *transport.Stream
p *parser
desc *StreamDesc
codec Codec
cp Compressor
cbuf *bytes.Buffer
dc Decompressor
cancel context.CancelFunc
opts []CallOption
c callInfo
t transport.ClientTransport
s *transport.Stream
p *parser
desc *StreamDesc
codec Codec
cp Compressor
cbuf *bytes.Buffer
dc Decompressor
maxMsgSize int
cancel context.CancelFunc

tracing bool // set to EnableTracing when the clientStream is created.

Expand Down Expand Up @@ -382,7 +383,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
Client: true,
}
}
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, inPayload)
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, inPayload)
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
Expand All @@ -405,7 +406,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
}
// Special handling for client streaming rpc.
// This recv expects EOF or errors, so we don't collect inPayload.
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32, nil)
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, cs.maxMsgSize, nil)
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
Expand Down
33 changes: 31 additions & 2 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,9 @@ func (te *test) clientConn() *grpc.ClientConn {
if te.streamClientInt != nil {
opts = append(opts, grpc.WithStreamInterceptor(te.streamClientInt))
}
if te.maxMsgSize > 0 {
opts = append(opts, grpc.WithMaxMsgSize(te.maxMsgSize))
}
switch te.e.security {
case "tls":
creds, err := credentials.NewClientTLSFromFile(tlsDir+"ca.pem", "x.test.youtube.com")
Expand Down Expand Up @@ -1427,22 +1430,33 @@ func testExceedMsgLimit(t *testing.T, e env) {
tc := testpb.NewTestServiceClient(te.clientConn())

argSize := int32(te.maxMsgSize + 1)
const respSize = 1
const smallSize = 1

payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
if err != nil {
t.Fatal(err)
}
smallPayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, smallSize)
if err != nil {
t.Fatal(err)
}

// test on server side for unary RPC
req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseSize: proto.Int32(respSize),
ResponseSize: proto.Int32(smallSize),
Payload: payload,
}
if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal)
}
// test on client side for unary RPC
req.ResponseSize = proto.Int32(int32(te.maxMsgSize) + 1)
if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %s", err, codes.Internal)
}

// test on server side for streaming RPC
stream, err := tc.FullDuplexCall(te.ctx)
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
Expand All @@ -1469,6 +1483,21 @@ func testExceedMsgLimit(t *testing.T, e env) {
if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal {
t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal)
}

// test on client side for streaming RPC
stream, err = tc.FullDuplexCall(te.ctx)
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
respParam[0].Size = proto.Int32(int32(te.maxMsgSize) + 1)
sreq.Payload = smallPayload
if err := stream.Send(sreq); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
}
if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal {
t.Fatalf("%v.Recv() = _, %v, want _, error code: %s", stream, err, codes.Internal)
}

}

func TestPeerClientSide(t *testing.T) {
Expand Down

0 comments on commit a94b094

Please sign in to comment.