Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Limit the max request size to be accepted #780

Merged
merged 10 commits into from
Jul 30, 2016
3 changes: 2 additions & 1 deletion call.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ package grpc
import (
"bytes"
"io"
"math"
"time"

"golang.org/x/net/context"
Expand Down Expand Up @@ -64,7 +65,7 @@ func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, s
}
p := &parser{r: stream}
for {
if err = recv(p, dopts.codec, stream, dopts.dc, reply); err != nil {
if err = recv(p, dopts.codec, stream, dopts.dc, reply, math.MaxInt32); err != nil {
if err == io.EOF {
break
}
Expand Down
2 changes: 1 addition & 1 deletion call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ type testStreamHandler struct {
func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
p := &parser{r: s}
for {
pf, req, err := p.recvMsg()
pf, req, err := p.recvMsg(math.MaxInt32)
if err == io.EOF {
break
}
Expand Down
18 changes: 13 additions & 5 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ type parser struct {
// No other error values or types must be returned, which also means
// that the underlying io.Reader must not return an incompatible
// error.
func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
func (p *parser) recvMsg(maxMsgSize int) (pf payloadFormat, msg []byte, err error) {
if _, err := io.ReadFull(p.r, p.header[:]); err != nil {
return 0, nil, err
}
Expand All @@ -238,6 +238,9 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {
if length == 0 {
return pf, nil, nil
}
if length > uint32(maxMsgSize) {
return 0, nil, Errorf(codes.Internal, "grpc: received message length %d exceeding the max size %d", length, maxMsgSize)
}
// TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
// of making it for each message:
msg = make([]byte, int(length))
Expand Down Expand Up @@ -308,8 +311,8 @@ func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) er
return nil
}

func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}) error {
pf, d, err := p.recvMsg()
func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{}, maxMsgSize int) error {
pf, d, err := p.recvMsg(maxMsgSize)
if err != nil {
return err
}
Expand All @@ -319,11 +322,16 @@ func recv(p *parser, c Codec, s *transport.Stream, dc Decompressor, m interface{
if pf == compressionMade {
d, err = dc.Do(bytes.NewReader(d))
if err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
return Errorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
}
}
if len(d) > maxMsgSize {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this validation be done before decompressing the message?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

additionally, it seems the message size validation is also being done in the parser: https://github.com/grpc/grpc-go/blob/master/rpc_util.go#L253

// TODO: Revisit the error code. Currently keep it consistent with java
// implementation.
return Errorf(codes.Internal, "grpc: received a message of %d bytes exceeding %d limit", len(d), maxMsgSize)
}
if err := c.Unmarshal(d, m); err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
return Errorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
}
return nil
}
Expand Down
13 changes: 7 additions & 6 deletions rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ package grpc
import (
"bytes"
"io"
"math"
"reflect"
"testing"

Expand Down Expand Up @@ -66,9 +67,9 @@ func TestSimpleParsing(t *testing.T) {
} {
buf := bytes.NewReader(test.p)
parser := &parser{r: buf}
pt, b, err := parser.recvMsg()
pt, b, err := parser.recvMsg(math.MaxInt32)
if err != test.err || !bytes.Equal(b, test.b) || pt != test.pt {
t.Fatalf("parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
t.Fatalf("parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, %v", test.p, pt, b, err, test.pt, test.b, test.err)
}
}
}
Expand All @@ -88,16 +89,16 @@ func TestMultipleParsing(t *testing.T) {
{compressionNone, []byte("d")},
}
for i, want := range wantRecvs {
pt, data, err := parser.recvMsg()
pt, data, err := parser.recvMsg(math.MaxInt32)
if err != nil || pt != want.pt || !reflect.DeepEqual(data, want.data) {
t.Fatalf("after %d calls, parser{%v}.recvMsg() = %v, %v, %v\nwant %v, %v, <nil>",
t.Fatalf("after %d calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant %v, %v, <nil>",
i, p, pt, data, err, want.pt, want.data)
}
}

pt, data, err := parser.recvMsg()
pt, data, err := parser.recvMsg(math.MaxInt32)
if err != io.EOF {
t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg() = %v, %v, %v\nwant _, _, %v",
t.Fatalf("after %d recvMsgs calls, parser{%v}.recvMsg(_) = %v, %v, %v\nwant _, _, %v",
len(wantRecvs), p, pt, data, err, io.EOF)
}
}
Expand Down
43 changes: 33 additions & 10 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,15 @@ type options struct {
codec Codec
cp Compressor
dc Decompressor
maxMsgSize int
unaryInt UnaryServerInterceptor
streamInt StreamServerInterceptor
maxConcurrentStreams uint32
useHandlerImpl bool // use http.Handler-based server
}

var defaultMaxMsgSize = 1024 * 1024 * 4 // use 4MB as the default message size limit

// A ServerOption sets options.
type ServerOption func(*options)

Expand All @@ -121,20 +124,28 @@ func CustomCodec(codec Codec) ServerOption {
}
}

// RPCCompressor returns a ServerOption that sets a compressor for outbound message.
// RPCCompressor returns a ServerOption that sets a compressor for outbound messages.
func RPCCompressor(cp Compressor) ServerOption {
return func(o *options) {
o.cp = cp
}
}

// RPCDecompressor returns a ServerOption that sets a decompressor for inbound message.
// RPCDecompressor returns a ServerOption that sets a decompressor for inbound messages.
func RPCDecompressor(dc Decompressor) ServerOption {
return func(o *options) {
o.dc = dc
}
}

// MaxMsgSize returns a ServerOption to set the max message size in bytes for inbound mesages.
// If this is not set, gRPC uses the default 4MB.
func MaxMsgSize(m int) ServerOption {
return func(o *options) {
o.maxMsgSize = m
}
}

// MaxConcurrentStreams returns a ServerOption that will apply a limit on the number
// of concurrent streams to each ServerTransport.
func MaxConcurrentStreams(n uint32) ServerOption {
Expand Down Expand Up @@ -177,6 +188,7 @@ func StreamInterceptor(i StreamServerInterceptor) ServerOption {
// started to accept requests yet.
func NewServer(opt ...ServerOption) *Server {
var opts options
opts.maxMsgSize = defaultMaxMsgSize
for _, o := range opt {
o(&opts)
}
Expand Down Expand Up @@ -526,7 +538,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
p := &parser{r: stream}
for {
pf, req, err := p.recvMsg()
pf, req, err := p.recvMsg(s.opts.maxMsgSize)
if err == io.EOF {
// The entire stream is done (for unary RPC only).
return err
Expand All @@ -536,6 +548,10 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
}
if err != nil {
switch err := err.(type) {
case *rpcError:
if err := t.WriteStatus(stream, err.code, err.desc); err != nil {
grpclog.Printf("grpc: Server.processUnaryRPC failed to write status %v", err)
}
case transport.ConnectionError:
// Nothing to do here.
case transport.StreamError:
Expand Down Expand Up @@ -575,6 +591,12 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport.
return err
}
}
if len(req) > s.opts.maxMsgSize {
// TODO: Revisit the error code. Currently keep it consistent with
// java implementation.
statusCode = codes.Internal
statusDesc = fmt.Sprintf("grpc: server received a message of %d bytes exceeding %d limit", len(req), s.opts.maxMsgSize)
}
if err := s.opts.codec.Unmarshal(req, v); err != nil {
return err
}
Expand Down Expand Up @@ -634,13 +656,14 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp
stream.SetSendCompress(s.opts.cp.Type())
}
ss := &serverStream{
t: t,
s: stream,
p: &parser{r: stream},
codec: s.opts.codec,
cp: s.opts.cp,
dc: s.opts.dc,
trInfo: trInfo,
t: t,
s: stream,
p: &parser{r: stream},
codec: s.opts.codec,
cp: s.opts.cp,
dc: s.opts.dc,
maxMsgSize: s.opts.maxMsgSize,
trInfo: trInfo,
}
if ss.cp != nil {
ss.cbuf = new(bytes.Buffer)
Expand Down
8 changes: 5 additions & 3 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
"bytes"
"errors"
"io"
"math"
"sync"
"time"

Expand Down Expand Up @@ -291,7 +292,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) {
}

func (cs *clientStream) RecvMsg(m interface{}) (err error) {
err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
defer func() {
// err != nil indicates the termination of the stream.
if err != nil {
Expand All @@ -310,7 +311,7 @@ func (cs *clientStream) RecvMsg(m interface{}) (err error) {
return
}
// Special handling for client streaming rpc.
err = recv(cs.p, cs.codec, cs.s, cs.dc, m)
err = recv(cs.p, cs.codec, cs.s, cs.dc, m, math.MaxInt32)
cs.closeTransportStream(err)
if err == nil {
return toRPCErr(errors.New("grpc: client streaming protocol violation: get <nil>, want <EOF>"))
Expand Down Expand Up @@ -411,6 +412,7 @@ type serverStream struct {
cp Compressor
dc Decompressor
cbuf *bytes.Buffer
maxMsgSize int
statusCode codes.Code
statusDesc string
trInfo *traceInfo
Expand Down Expand Up @@ -477,5 +479,5 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) {
ss.mu.Unlock()
}
}()
return recv(ss.p, ss.codec, ss.s, ss.dc, m)
return recv(ss.p, ss.codec, ss.s, ss.dc, m, ss.maxMsgSize)
}
63 changes: 63 additions & 0 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ type test struct {
testServer testpb.TestServiceServer // nil means none
healthServer *health.Server // nil means disabled
maxStream uint32
maxMsgSize int
userAgent string
clientCompression bool
serverCompression bool
Expand Down Expand Up @@ -413,6 +414,9 @@ func (te *test) startServer(ts testpb.TestServiceServer) {
e := te.e
te.t.Logf("Running test in %s environment...", e.name)
sopts := []grpc.ServerOption{grpc.MaxConcurrentStreams(te.maxStream)}
if te.maxMsgSize > 0 {
sopts = append(sopts, grpc.MaxMsgSize(te.maxMsgSize))
}
if te.serverCompression {
sopts = append(sopts,
grpc.RPCCompressor(grpc.NewGZIPCompressor()),
Expand Down Expand Up @@ -1068,6 +1072,65 @@ func testLargeUnary(t *testing.T, e env) {
}
}

func TestExceedMsgLimit(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
testExceedMsgLimit(t, e)
}
}

func testExceedMsgLimit(t *testing.T, e env) {
te := newTest(t, e)
te.maxMsgSize = 1024
te.startServer(&testServer{security: e.security})
defer te.tearDown()
tc := testpb.NewTestServiceClient(te.clientConn())

argSize := int32(te.maxMsgSize + 1)
const respSize = 1

payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, argSize)
if err != nil {
t.Fatal(err)
}

req := &testpb.SimpleRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseSize: proto.Int32(respSize),
Payload: payload,
}
if _, err := tc.UnaryCall(context.Background(), req); err == nil || grpc.Code(err) != codes.Internal {
t.Fatalf("TestService/UnaryCall(_, _) = _, %v, want _, error code: %d", err, codes.Internal)
}

stream, err := tc.FullDuplexCall(te.ctx)
if err != nil {
t.Fatalf("%v.FullDuplexCall(_) = _, %v, want <nil>", tc, err)
}
respParam := []*testpb.ResponseParameters{
{
Size: proto.Int32(1),
},
}

spayload, err := newPayload(testpb.PayloadType_COMPRESSABLE, int32(te.maxMsgSize+1))
if err != nil {
t.Fatal(err)
}

sreq := &testpb.StreamingOutputCallRequest{
ResponseType: testpb.PayloadType_COMPRESSABLE.Enum(),
ResponseParameters: respParam,
Payload: spayload,
}
if err := stream.Send(sreq); err != nil {
t.Fatalf("%v.Send(%v) = %v, want <nil>", stream, sreq, err)
}
if _, err := stream.Recv(); err == nil || grpc.Code(err) != codes.Internal {
t.Fatalf("%v.Recv() = _, %v, want _, error code: %d", stream, err, codes.Internal)
}
}

func TestMetadataUnaryRPC(t *testing.T) {
defer leakCheck(t)()
for _, e := range listTestEnv() {
Expand Down