Skip to content

Commit

Permalink
Merge pull request #500 from iamqizhao/master
Browse files Browse the repository at this point in the history
Support connection level compression
  • Loading branch information
iamqizhao committed Jan 25, 2016
2 parents 5da22b9 + 8ced3f9 commit e29d659
Show file tree
Hide file tree
Showing 12 changed files with 527 additions and 122 deletions.
25 changes: 18 additions & 7 deletions call.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
package grpc

import (
"bytes"
"io"
"time"

Expand All @@ -47,7 +48,7 @@ import (
// On error, it returns the error and indicates whether the call should be retried.
//
// TODO(zhaoq): Check whether the received message sequence is valid.
func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
func recvResponse(dopts dialOptions, t transport.ClientTransport, c *callInfo, stream *transport.Stream, reply interface{}) error {
// Try to acquire header metadata from the server if there is any.
var err error
c.headerMD, err = stream.Header()
Expand All @@ -56,7 +57,7 @@ func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream
}
p := &parser{s: stream}
for {
if err = recv(p, codec, reply); err != nil {
if err = recv(p, dopts.codec, stream, dopts.dg, reply); err != nil {
if err == io.EOF {
break
}
Expand All @@ -68,7 +69,7 @@ func recvResponse(codec Codec, t transport.ClientTransport, c *callInfo, stream
}

// sendRequest writes out various information of an RPC such as Context and Message.
func sendRequest(ctx context.Context, codec Codec, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
func sendRequest(ctx context.Context, codec Codec, compressor Compressor, callHdr *transport.CallHdr, t transport.ClientTransport, args interface{}, opts *transport.Options) (_ *transport.Stream, err error) {
stream, err := t.NewStream(ctx, callHdr)
if err != nil {
return nil, err
Expand All @@ -80,8 +81,11 @@ func sendRequest(ctx context.Context, codec Codec, callHdr *transport.CallHdr, t
}
}
}()
// TODO(zhaoq): Support compression.
outBuf, err := encode(codec, args, compressionNone)
var cbuf *bytes.Buffer
if compressor != nil {
cbuf = new(bytes.Buffer)
}
outBuf, err := encode(codec, args, compressor, cbuf)
if err != nil {
return nil, transport.StreamErrorf(codes.Internal, "grpc: %v", err)
}
Expand Down Expand Up @@ -129,7 +133,11 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
}
var (
lastErr error // record the error that happened
cp Compressor
)
if cc.dopts.cg != nil {
cp = cc.dopts.cg()
}
for {
var (
err error
Expand All @@ -144,6 +152,9 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
Host: cc.authority,
Method: method,
}
if cp != nil {
callHdr.SendCompress = cp.Type()
}
t, err = cc.dopts.picker.Pick(ctx)
if err != nil {
if lastErr != nil {
Expand All @@ -155,7 +166,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
if c.traceInfo.tr != nil {
c.traceInfo.tr.LazyLog(&payload{sent: true, msg: args}, true)
}
stream, err = sendRequest(ctx, cc.dopts.codec, callHdr, t, args, topts)
stream, err = sendRequest(ctx, cc.dopts.codec, cp, callHdr, t, args, topts)
if err != nil {
if _, ok := err.(transport.ConnectionError); ok {
lastErr = err
Expand All @@ -167,7 +178,7 @@ func Invoke(ctx context.Context, method string, args, reply interface{}, cc *Cli
return toRPCErr(err)
}
// Receive the response
lastErr = recvResponse(cc.dopts.codec, t, &c, stream, reply)
lastErr = recvResponse(cc.dopts, t, &c, stream, reply)
if _, ok := lastErr.(transport.ConnectionError); ok {
continue
}
Expand Down
2 changes: 1 addition & 1 deletion call_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) {
}
}
// send a response back to end the stream.
reply, err := encode(testCodec{}, &expectedResponse, compressionNone)
reply, err := encode(testCodec{}, &expectedResponse, nil, nil)
if err != nil {
t.Fatalf("Failed to encode the response: %v", err)
}
Expand Down
18 changes: 18 additions & 0 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ var (
// values passed to Dial.
type dialOptions struct {
codec Codec
cg CompressorGenerator
dg DecompressorGenerator
picker Picker
block bool
insecure bool
Expand All @@ -89,6 +91,22 @@ func WithCodec(c Codec) DialOption {
}
}

// WithCompressor returns a DialOption which sets a CompressorGenerator for generating message
// compressor.
func WithCompressor(f CompressorGenerator) DialOption {
return func(o *dialOptions) {
o.cg = f
}
}

// WithDecompressor returns a DialOption which sets a DecompressorGenerator for generating
// message decompressor.
func WithDecompressor(f DecompressorGenerator) DialOption {
return func(o *dialOptions) {
o.dg = f
}
}

// WithPicker returns a DialOption which sets a picker for connection selection.
func WithPicker(p Picker) DialOption {
return func(o *dialOptions) {
Expand Down
125 changes: 110 additions & 15 deletions rpc_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,12 @@
package grpc

import (
"bytes"
"compress/gzip"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"math"
"math/rand"
"os"
Expand Down Expand Up @@ -75,6 +78,69 @@ func (protoCodec) String() string {
return "proto"
}

// Compressor defines the interface gRPC uses to compress a message.
type Compressor interface {
// Do compresses p into w.
Do(w io.Writer, p []byte) error
// Type returns the compression algorithm the Compressor uses.
Type() string
}

// NewGZIPCompressor creates a Compressor based on GZIP.
func NewGZIPCompressor() Compressor {
return &gzipCompressor{}
}

type gzipCompressor struct {
}

func (c *gzipCompressor) Do(w io.Writer, p []byte) error {
z := gzip.NewWriter(w)
if _, err := z.Write(p); err != nil {
return err
}
return z.Close()
}

func (c *gzipCompressor) Type() string {
return "gzip"
}

// Decompressor defines the interface gRPC uses to decompress a message.
type Decompressor interface {
// Do reads the data from r and uncompress them.
Do(r io.Reader) ([]byte, error)
// Type returns the compression algorithm the Decompressor uses.
Type() string
}

type gzipDecompressor struct {
}

// NewGZIPDecompressor creates a Decompressor based on GZIP.
func NewGZIPDecompressor() Decompressor {
return &gzipDecompressor{}
}

func (d *gzipDecompressor) Do(r io.Reader) ([]byte, error) {
z, err := gzip.NewReader(r)
if err != nil {
return nil, err
}
defer z.Close()
return ioutil.ReadAll(z)
}

func (d *gzipDecompressor) Type() string {
return "gzip"
}

// CompressorGenerator defines the function generating a Compressor.
type CompressorGenerator func() Compressor

// DecompressorGenerator defines the function generating a Decompressor.
type DecompressorGenerator func() Decompressor

// callInfo contains all related configuration and information about an RPC.
type callInfo struct {
failFast bool
Expand Down Expand Up @@ -126,8 +192,7 @@ type payloadFormat uint8

const (
compressionNone payloadFormat = iota // no compression
compressionFlate
// More formats
compressionMade
)

// parser reads complelete gRPC messages from the underlying reader.
Expand Down Expand Up @@ -166,7 +231,7 @@ func (p *parser) recvMsg() (pf payloadFormat, msg []byte, err error) {

// encode serializes msg and prepends the message header. If msg is nil, it
// generates the message header of 0 message length.
func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) {
func encode(c Codec, msg interface{}, cp Compressor, cbuf *bytes.Buffer) ([]byte, error) {
var b []byte
var length uint
if msg != nil {
Expand All @@ -176,6 +241,12 @@ func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) {
if err != nil {
return nil, err
}
if cp != nil {
if err := cp.Do(cbuf, b); err != nil {
return nil, err
}
b = cbuf.Bytes()
}
length = uint(len(b))
}
if length > math.MaxUint32 {
Expand All @@ -190,7 +261,11 @@ func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) {
var buf = make([]byte, payloadLen+sizeLen+len(b))

// Write payload format
buf[0] = byte(pf)
if cp == nil {
buf[0] = byte(compressionNone)
} else {
buf[0] = byte(compressionMade)
}
// Write length of b into buf
binary.BigEndian.PutUint32(buf[1:], uint32(length))
// Copy encoded msg to buf
Expand All @@ -199,22 +274,42 @@ func encode(c Codec, msg interface{}, pf payloadFormat) ([]byte, error) {
return buf, nil
}

func recv(p *parser, c Codec, m interface{}) error {
func checkRecvPayload(pf payloadFormat, recvCompress string, dc Decompressor) error {
switch pf {
case compressionNone:
case compressionMade:
if recvCompress == "" {
return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf)
}
if dc == nil || recvCompress != dc.Type() {
return transport.StreamErrorf(codes.InvalidArgument, "grpc: Decompressor is not installed for grpc-encoding %q", recvCompress)
}
default:
return transport.StreamErrorf(codes.InvalidArgument, "grpc: received unexpected payload format %d", pf)
}
return nil
}

func recv(p *parser, c Codec, s *transport.Stream, dg DecompressorGenerator, m interface{}) error {
pf, d, err := p.recvMsg()
if err != nil {
return err
}
switch pf {
case compressionNone:
if err := c.Unmarshal(d, m); err != nil {
if rErr, ok := err.(rpcError); ok {
return rErr
} else {
return Errorf(codes.Internal, "grpc: %v", err)
}
var dc Decompressor
if pf == compressionMade && dg != nil {
dc = dg()
}
if err := checkRecvPayload(pf, s.RecvCompress(), dc); err != nil {
return err
}
if pf == compressionMade {
d, err = dc.Do(bytes.NewReader(d))
if err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to decompress the received message %v", err)
}
default:
return Errorf(codes.Internal, "gprc: compression is not supported yet.")
}
if err := c.Unmarshal(d, m); err != nil {
return transport.StreamErrorf(codes.Internal, "grpc: failed to unmarshal the received message %v", err)
}
return nil
}
Expand Down
36 changes: 30 additions & 6 deletions rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,40 @@ func TestEncode(t *testing.T) {
for _, test := range []struct {
// input
msg proto.Message
pt payloadFormat
cp Compressor
// outputs
b []byte
err error
}{
{nil, compressionNone, []byte{0, 0, 0, 0, 0}, nil},
{nil, nil, []byte{0, 0, 0, 0, 0}, nil},
} {
b, err := encode(protoCodec{}, test.msg, test.pt)
b, err := encode(protoCodec{}, test.msg, nil, nil)
if err != test.err || !bytes.Equal(b, test.b) {
t.Fatalf("encode(_, _, %d) = %v, %v\nwant %v, %v", test.pt, b, err, test.b, test.err)
t.Fatalf("encode(_, _, %v, _) = %v, %v\nwant %v, %v", test.cp, b, err, test.b, test.err)
}
}
}

func TestCompress(t *testing.T) {
for _, test := range []struct {
// input
data []byte
cp Compressor
dc Decompressor
// outputs
err error
}{
{make([]byte, 1024), &gzipCompressor{}, &gzipDecompressor{}, nil},
} {
b := new(bytes.Buffer)
if err := test.cp.Do(b, test.data); err != test.err {
t.Fatalf("Compressor.Do(_, %v) = %v, want %v", test.data, err, test.err)
}
if b.Len() >= len(test.data) {
t.Fatalf("The compressor fails to compress data.")
}
if p, err := test.dc.Do(b); err != nil || !bytes.Equal(test.data, p) {
t.Fatalf("Decompressor.Do(%v) = %v, %v, want %v, <nil>", b, p, err, test.data)
}
}
}
Expand Down Expand Up @@ -158,12 +182,12 @@ func TestContextErr(t *testing.T) {
// bytes.
func bmEncode(b *testing.B, mSize int) {
msg := &perfpb.Buffer{Body: make([]byte, mSize)}
encoded, _ := encode(protoCodec{}, msg, compressionNone)
encoded, _ := encode(protoCodec{}, msg, nil, nil)
encodedSz := int64(len(encoded))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
encode(protoCodec{}, msg, compressionNone)
encode(protoCodec{}, msg, nil, nil)
}
b.SetBytes(encodedSz)
}
Expand Down
Loading

0 comments on commit e29d659

Please sign in to comment.