From 82babd01e56186d42411b4850aaff1e557ab99bc Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Tue, 24 Oct 2017 14:42:04 +0100
Subject: [PATCH 01/17] Allow negative values for getArrayLength()

Kafka supports nullable arrays, and their null value is represented by a
length of -1.
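Reading the length with binary.BigEndian.Uint32 alone loses the sign, so a
null array used to decode as a huge positive length. A standalone sketch (not
part of the patch) of why the extra int32 conversion matters:

    package main

    import (
    	"encoding/binary"
    	"fmt"
    )

    func main() {
    	null := []byte{0xff, 0xff, 0xff, 0xff} // a null array's length on the wire
    	fmt.Println(int(binary.BigEndian.Uint32(null)))        // 4294967295 on 64-bit platforms
    	fmt.Println(int(int32(binary.BigEndian.Uint32(null)))) // -1
    }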
---
 real_decoder.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/real_decoder.go b/real_decoder.go
index 4a05863a9..05bcd207f 100644
--- a/real_decoder.go
+++ b/real_decoder.go
@@ -79,7 +79,7 @@ func (rd *realDecoder) getArrayLength() (int, error) {
 		rd.off = len(rd.raw)
 		return -1, ErrInsufficientData
 	}
-	tmp := int(binary.BigEndian.Uint32(rd.raw[rd.off:]))
+	tmp := int(int32(binary.BigEndian.Uint32(rd.raw[rd.off:])))
 	rd.off += 4
 	if tmp > rd.remaining() {
 		rd.off = len(rd.raw)

From f554f211f2792d6dbc6e76b6ca88165ef615da3e Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Tue, 24 Oct 2017 14:47:13 +0100
Subject: [PATCH 02/17] Add support for Kafka 0.11 Record format.

Kafka 0.11 introduces a new Record format that replaces Message from the
previous versions. The new format allows for Headers, which are key-value
pairs of application metadata associated with each message.
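As an illustrative, in-package sketch (the key, value and header bytes are
made up), a record carrying one application header would be built like this:

    rec := &Record{
    	TimestampDelta: 5,
    	Key:            []byte("key"),
    	Value:          []byte("value"),
    	Headers: []*Header{{
    		Key:   []byte("trace-id"),
    		Value: []byte("abc123"),
    	}},
    }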
---
 record.go | 149 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 149 insertions(+)
 create mode 100644 record.go

diff --git a/record.go b/record.go
new file mode 100644
index 000000000..427aef1c3
--- /dev/null
+++ b/record.go
@@ -0,0 +1,149 @@
+package sarama
+
+import "encoding/binary"
+
+const (
+	controlMask = 0x20
+)
+
+type Header struct {
+	Key   []byte
+	Value []byte
+}
+
+func (h *Header) encode(pe packetEncoder) error {
+	if err := pe.putVarintBytes(h.Key); err != nil {
+		return err
+	}
+	return pe.putVarintBytes(h.Value)
+}
+
+func (h *Header) decode(pd packetDecoder) (err error) {
+	if h.Key, err = pd.getVarintBytes(); err != nil {
+		return err
+	}
+
+	if h.Value, err = pd.getVarintBytes(); err != nil {
+		return err
+	}
+	return nil
+}
+
+type Record struct {
+	Attributes     int8
+	TimestampDelta int64
+	OffsetDelta    int64
+	Key            []byte
+	Value          []byte
+	Headers        []*Header
+
+	lengthComputed bool
+	length         int64
+	totalLength    int
+}
+
+func (r *Record) encode(pe packetEncoder) error {
+	if err := r.computeLength(); err != nil {
+		return err
+	}
+
+	pe.putVarint(r.length)
+	pe.putInt8(r.Attributes)
+	pe.putVarint(r.TimestampDelta)
+	pe.putVarint(r.OffsetDelta)
+	if err := pe.putVarintBytes(r.Key); err != nil {
+		return err
+	}
+	if err := pe.putVarintBytes(r.Value); err != nil {
+		return err
+	}
+	pe.putVarint(int64(len(r.Headers)))
+
+	for _, h := range r.Headers {
+		if err := h.encode(pe); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (r *Record) decode(pd packetDecoder) (err error) {
+	length, err := newVarintLengthField(pd)
+	if err != nil {
+		return err
+	}
+	if err = pd.push(length); err != nil {
+		return err
+	}
+	r.length = length.length
+	r.lengthComputed = true
+
+	if r.Attributes, err = pd.getInt8(); err != nil {
+		return err
+	}
+
+	if r.TimestampDelta, err = pd.getVarint(); err != nil {
+		return err
+	}
+
+	if r.OffsetDelta, err = pd.getVarint(); err != nil {
+		return err
+	}
+
+	if r.Key, err = pd.getVarintBytes(); err != nil {
+		return err
+	}
+
+	if r.Value, err = pd.getVarintBytes(); err != nil {
+		return err
+	}
+
+	numHeaders, err := pd.getVarint()
+	if err != nil {
+		return err
+	}
+
+	if numHeaders >= 0 {
+		r.Headers = make([]*Header, numHeaders)
+	}
+	for i := int64(0); i < numHeaders; i++ {
+		hdr := new(Header)
+		if err := hdr.decode(pd); err != nil {
+			return err
+		}
+
+		r.Headers[i] = hdr
+	}
+
+	return pd.pop()
+}
+
+// Because the length is varint we can't reserve a fixed amount of bytes for it.
+// We use the prepEncoder to figure out the length of the record and then we cache it.
+func (r *Record) computeLength() error {
+	if !r.lengthComputed {
+		r.lengthComputed = true
+
+		var prep prepEncoder
+		if err := r.encode(&prep); err != nil {
+			return err
+		}
+		// subtract 1 because we don't want to include the length field itself (which is 1 byte, the
+		// length of the varint encoding of 0)
+		r.length = int64(prep.length) - 1
+	}
+
+	return nil
+}
+
+func (r *Record) getTotalLength() (int, error) {
+	if r.totalLength == 0 {
+		if err := r.computeLength(); err != nil {
+			return 0, err
+		}
+		var buf [binary.MaxVarintLen64]byte
+		r.totalLength = int(r.length) + binary.PutVarint(buf[:], r.length)
+	}
+
+	return r.totalLength, nil
+}

From 7630f80a01ff1db4e0e6e91e2def16186ca727f3 Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Tue, 24 Oct 2017 14:50:19 +0100
Subject: [PATCH 03/17] Add support for Kafka 0.11 RecordBatch

Kafka 0.11 introduced RecordBatch as a successor to MessageSet. Using the
new RecordBatch is required for transactions and idempotent message
delivery.
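The batch-level attributes word packs the compression codec into its low bits
and the control flag into bit 5 (0x20); see computeAttributes below. For
example (an illustrative sketch, assuming sarama's existing codec constants):

    // A snappy-compressed control batch:
    attr := int16(CompressionSnappy) & int16(compressionCodecMask) // 0x02
    attr |= controlMask                                            // 0x22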
---
 record_batch.go | 254 +++++++++++++++++++++++++++++++++++++++++++++
 record_test.go  | 268 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 522 insertions(+)
 create mode 100644 record_batch.go
 create mode 100644 record_test.go

diff --git a/record_batch.go b/record_batch.go
new file mode 100644
index 000000000..83d6f25c7
--- /dev/null
+++ b/record_batch.go
@@ -0,0 +1,254 @@
+package sarama
+
+import (
+	"bytes"
+	"compress/gzip"
+	"fmt"
+	"io/ioutil"
+
+	"github.com/eapache/go-xerial-snappy"
+	"github.com/pierrec/lz4"
+)
+
+const recordBatchOverhead = 49
+
+type RecordBatch struct {
+	FirstOffset           int64
+	PartitionLeaderEpoch  int32
+	Version               int8
+	Codec                 CompressionCodec
+	Control               bool
+	LastOffsetDelta       int32
+	FirstTimestamp        int64
+	MaxTimestamp          int64
+	ProducerID            int64
+	ProducerEpoch         int16
+	FirstSequence         int32
+	Records               []*Record
+	PartialTrailingRecord bool
+
+	compressedRecords []byte
+	recordsLen        int
+}
+
+func (b *RecordBatch) encode(pe packetEncoder) error {
+	if b.Version != 2 {
+		return PacketEncodingError{fmt.Sprintf("unsupported record batch version (%d)", b.Version)}
+	}
+	pe.putInt64(b.FirstOffset)
+	pe.push(&lengthField{})
+	pe.putInt32(b.PartitionLeaderEpoch)
+	pe.putInt8(b.Version)
+	pe.push(newCRC32Field(crcCastagnoli))
+	pe.putInt16(b.computeAttributes())
+	pe.putInt32(b.LastOffsetDelta)
+	pe.putInt64(b.FirstTimestamp)
+	pe.putInt64(b.MaxTimestamp)
+	pe.putInt64(b.ProducerID)
+	pe.putInt16(b.ProducerEpoch)
+	pe.putInt32(b.FirstSequence)
+
+	if err := pe.putArrayLength(len(b.Records)); err != nil {
+		return err
+	}
+
+	if b.compressedRecords != nil {
+		if err := pe.putRawBytes(b.compressedRecords); err != nil {
+			return err
+		}
+		if err := pe.pop(); err != nil {
+			return err
+		}
+		if err := pe.pop(); err != nil {
+			return err
+		}
+		return nil
+	}
+
+	var re packetEncoder
+	var raw []byte
+
+	switch b.Codec {
+	case CompressionNone:
+		re = pe
+	case CompressionGZIP, CompressionLZ4, CompressionSnappy:
+		for _, r := range b.Records {
+			l, err := r.getTotalLength()
+			if err != nil {
+				return err
+			}
+			b.recordsLen += l
+		}
+
+		raw = make([]byte, b.recordsLen)
+		re = &realEncoder{raw: raw}
+	default:
+		return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
+	}
+
+	for _, r := range b.Records {
+		if err := r.encode(re); err != nil {
+			return err
+		}
+	}
+
+	switch b.Codec {
+	case CompressionGZIP:
+		var buf bytes.Buffer
+		writer := gzip.NewWriter(&buf)
+		if _, err := writer.Write(raw); err != nil {
+			return err
+		}
+		if err := writer.Close(); err != nil {
+			return err
+		}
+		b.compressedRecords = buf.Bytes()
+	case CompressionSnappy:
+		b.compressedRecords = snappy.Encode(raw)
+	case CompressionLZ4:
+		var buf bytes.Buffer
+		writer := lz4.NewWriter(&buf)
+		if _, err := writer.Write(raw); err != nil {
+			return err
+		}
+		if err := writer.Close(); err != nil {
+			return err
+		}
+		b.compressedRecords = buf.Bytes()
+	}
+	if err := pe.putRawBytes(b.compressedRecords); err != nil {
+		return err
+	}
+
+	if err := pe.pop(); err != nil {
+		return err
+	}
+	if err := pe.pop(); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (b *RecordBatch) decode(pd packetDecoder) (err error) {
+	if b.FirstOffset, err = pd.getInt64(); err != nil {
+		return err
+	}
+
+	var batchLen int32
+	if batchLen, err = pd.getInt32(); err != nil {
+		return err
+	}
+
+	if b.PartitionLeaderEpoch, err = pd.getInt32(); err != nil {
+		return err
+	}
+
+	if b.Version, err = pd.getInt8(); err != nil {
+		return err
+	}
+
+	if err = pd.push(&crc32Field{polynomial: crcCastagnoli}); err != nil {
+		return err
+	}
+
+	var attributes int16
+	if attributes, err = pd.getInt16(); err != nil {
+		return err
+	}
+	b.Codec = CompressionCodec(int8(attributes) & compressionCodecMask)
+	b.Control = attributes&controlMask == controlMask
+
+	if b.LastOffsetDelta, err = pd.getInt32(); err != nil {
+		return err
+	}
+
+	if b.FirstTimestamp, err = pd.getInt64(); err != nil {
+		return err
+	}
+
+	if b.MaxTimestamp, err = pd.getInt64(); err != nil {
+		return err
+	}
+
+	if b.ProducerID, err = pd.getInt64(); err != nil {
+		return err
+	}
+
+	if b.ProducerEpoch, err = pd.getInt16(); err != nil {
+		return err
+	}
+
+	if b.FirstSequence, err = pd.getInt32(); err != nil {
+		return err
+	}
+
+	numRecs, err := pd.getArrayLength()
+	if err != nil {
+		return err
+	}
+	if numRecs >= 0 {
+		b.Records = make([]*Record, numRecs)
+	}
+
+	bufSize := int(batchLen) - recordBatchOverhead
+	recBuffer, err := pd.getRawBytes(bufSize)
+	if err != nil {
+		return err
+	}
+
+	if err = pd.pop(); err != nil {
+		return err
+	}
+
+	switch b.Codec {
+	case CompressionNone:
+	case CompressionGZIP:
+		reader, err := gzip.NewReader(bytes.NewReader(recBuffer))
+		if err != nil {
+			return err
+		}
+		if recBuffer, err = ioutil.ReadAll(reader); err != nil {
+			return err
+		}
+	case CompressionSnappy:
+		if recBuffer, err = snappy.Decode(recBuffer); err != nil {
+			return err
+		}
+	case CompressionLZ4:
+		reader := lz4.NewReader(bytes.NewReader(recBuffer))
+		if recBuffer, err = ioutil.ReadAll(reader); err != nil {
+			return err
+		}
+	default:
+		return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", b.Codec)}
+	}
+	recPd := &realDecoder{raw: recBuffer}
+
+	for i := 0; i < numRecs; i++ {
+		rec := &Record{}
+		if err = rec.decode(recPd); err != nil {
+			if err == ErrInsufficientData {
+				b.PartialTrailingRecord = true
+				b.Records = nil
+				return nil
+			}
+			return err
+		}
+		b.Records[i] = rec
+	}
+
+	return nil
+}
+
+func (b *RecordBatch) computeAttributes() int16 {
+	attr := int16(b.Codec) & int16(compressionCodecMask)
+	if b.Control {
+		attr |= controlMask
+	}
+	return attr
+}
+
+func (b *RecordBatch) addRecord(r *Record) {
+	b.Records = append(b.Records, r)
+}
diff --git a/record_test.go b/record_test.go
new file mode 100644
index 000000000..b1258f37f
--- /dev/null
+++ b/record_test.go
@@ -0,0 +1,268 @@
+package sarama
+
+import (
+	"reflect"
+	"runtime"
+	"strconv"
+	"strings"
+	"testing"
+
+	"github.com/davecgh/go-spew/spew"
+)
+
+var recordBatchTestCases = []struct {
+	name         string
+	batch        RecordBatch
+	encoded      []byte
+	oldGoEncoded []byte // used in case of gzipped content for go versions prior to 1.8
+}{
+	{
+		name:  "empty record",
+		batch: RecordBatch{Version: 2, Records: []*Record{}},
+		encoded: []byte{
+			0, 0, 0, 0, 0, 0, 0, 0, // First Offset
+			0, 0, 0, 49, // Length
+			0, 0, 0, 0, // Partition Leader Epoch
+			2,                // Version
+			89, 95, 183, 221, // CRC
+			0, 0, // Attributes
+			0, 0, 0, 0, // Last Offset Delta
+			0, 0, 0, 0, 0, 0, 0, 0, // First Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
+			0, 0, // Producer Epoch
+			0, 0, 0, 0, // First Sequence
+			0, 0, 0, 0, // Number of Records
+		},
+	},
+	{
+		name:  "control batch",
+		batch: RecordBatch{Version: 2, Control: true, Records: []*Record{}},
+		encoded: []byte{
+			0, 0, 0, 0, 0, 0, 0, 0, // First Offset
+			0, 0, 0, 49, // Length
+			0, 0, 0, 0, // Partition Leader Epoch
+			2,               // Version
+			81, 46, 67, 217, // CRC
+			0, 32, // Attributes
+			0, 0, 0, 0, // Last Offset Delta
+			0, 0, 0, 0, 0, 0, 0, 0, // First Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
+			0, 0, // Producer Epoch
+			0, 0, 0, 0, // First Sequence
+			0, 0, 0, 0, // Number of Records
+		},
+	},
+	{
+		name: "uncompressed record",
+		batch: RecordBatch{
+			Version:        2,
+			FirstTimestamp: 10,
+			Records: []*Record{{
+				TimestampDelta: 5,
+				Key:            []byte{1, 2, 3, 4},
+				Value:          []byte{5, 6, 7},
+				Headers: []*Header{{
+					Key:   []byte{8, 9, 10},
+					Value: []byte{11, 12},
+				}},
+			}},
+		},
+		encoded: []byte{
+			0, 0, 0, 0, 0, 0, 0, 0, // First Offset
+			0, 0, 0, 70, // Length
+			0, 0, 0, 0, // Partition Leader Epoch
+			2,                // Version
+			219, 71, 20, 201, // CRC
+			0, 0, // Attributes
+			0, 0, 0, 0, // Last Offset Delta
+			0, 0, 0, 0, 0, 0, 0, 10, // First Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
+			0, 0, // Producer Epoch
+			0, 0, 0, 0, // First Sequence
+			0, 0, 0, 1, // Number of Records
+			40, // Record Length
+			0,  // Attributes
+			10, // Timestamp Delta
+			0,  // Offset Delta
+			8,  // Key Length
+			1, 2, 3, 4,
+			6, // Value Length
+			5, 6, 7,
+			2, // Number of Headers
+			6, // Header Key Length
+			8, 9, 10, // Header Key
+			4, // Header Value Length
+			11, 12, // Header Value
+		},
+	},
+	{
+		name: "gzipped record",
+		batch: RecordBatch{
+			Version:        2,
+			Codec:          CompressionGZIP,
+			FirstTimestamp: 10,
+			Records: []*Record{{
+				TimestampDelta: 5,
+				Key:            []byte{1, 2, 3, 4},
+				Value:          []byte{5, 6, 7},
+				Headers: []*Header{{
+					Key:   []byte{8, 9, 10},
+					Value: []byte{11, 12},
+				}},
+			}},
+		},
+		encoded: []byte{
+			0, 0, 0, 0, 0, 0, 0, 0, // First Offset
+			0, 0, 0, 94, // Length
+			0, 0, 0, 0, // Partition Leader Epoch
+			2,               // Version
+			15, 156, 184, 78, // CRC
+			0, 1, // Attributes
+			0, 0, 0, 0, // Last Offset Delta
+			0, 0, 0, 0, 0, 0, 0, 10, // First Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
+			0, 0, // Producer Epoch
+			0, 0, 0, 0, // First Sequence
+			0, 0, 0, 1, // Number of Records
+			31, 139, 8, 0, 0, 0, 0, 0, 0, 255, 210, 96, 224, 98, 224, 96, 100, 98, 102, 97, 99, 101,
+			99, 103, 98, 227, 224, 228, 98, 225, 230, 1, 4, 0, 0, 255, 255, 173, 201, 88, 103, 21, 0, 0, 0,
+		},
+		oldGoEncoded: []byte{
+			0, 0, 0, 0, 0, 0, 0, 0, // First Offset
+			0, 0, 0, 94, // Length
+			0, 0, 0, 0, // Partition Leader Epoch
+			2,              // Version
+			144, 168, 0, 33, // CRC
+			0, 1, // Attributes
+			0, 0, 0, 0, // Last Offset Delta
+			0, 0, 0, 0, 0, 0, 0, 10, // First Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
+			0, 0, // Producer Epoch
+			0, 0, 0, 0, // First Sequence
+			0, 0, 0, 1, // Number of Records
+			31, 139, 8, 0, 0, 9, 110, 136, 0, 255, 210, 96, 224, 98, 224, 96, 100, 98, 102, 97, 99, 101,
+			99, 103, 98, 227, 224, 228, 98, 225, 230, 1, 4, 0, 0, 255, 255, 173, 201, 88, 103, 21, 0, 0, 0,
+		},
+	},
+	{
+		name: "snappy compressed record",
+		batch: RecordBatch{
+			Version:        2,
+			Codec:          CompressionSnappy,
+			FirstTimestamp: 10,
+			Records: []*Record{{
+				TimestampDelta: 5,
+				Key:            []byte{1, 2, 3, 4},
+				Value:          []byte{5, 6, 7},
+				Headers: []*Header{{
+					Key:   []byte{8, 9, 10},
+					Value: []byte{11, 12},
+				}},
+			}},
+		},
+		encoded: []byte{
+			0, 0, 0, 0, 0, 0, 0, 0, // First Offset
+			0, 0, 0, 72, // Length
+			0, 0, 0, 0, // Partition Leader Epoch
+			2,               // Version
+			95, 173, 35, 17, // CRC
+			0, 2, // Attributes
+			0, 0, 0, 0, // Last Offset Delta
+			0, 0, 0, 0, 0, 0, 0, 10, // First Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
+			0, 0, // Producer Epoch
+			0, 0, 0, 0, // First Sequence
+			0, 0, 0, 1, // Number of Records
+			21, 80, 40, 0, 10, 0, 8, 1, 2, 3, 4, 6, 5, 6, 7, 2, 6, 8, 9, 10, 4, 11, 12,
+		},
+	},
+	{
+		name: "lz4 compressed record",
+		batch: RecordBatch{
+			Version:        2,
+			Codec:          CompressionLZ4,
+			FirstTimestamp: 10,
+			Records: []*Record{{
+				TimestampDelta: 5,
+				Key:            []byte{1, 2, 3, 4},
+				Value:          []byte{5, 6, 7},
+				Headers: []*Header{{
+					Key:   []byte{8, 9, 10},
+					Value: []byte{11, 12},
+				}},
+			}},
+		},
+		encoded: []byte{
+			0, 0, 0, 0, 0, 0, 0, 0, // First Offset
+			0, 0, 0, 89, // Length
+			0, 0, 0, 0, // Partition Leader Epoch
+			2,               // Version
+			129, 238, 43, 82, // CRC
+			0, 3, // Attributes
+			0, 0, 0, 0, // Last Offset Delta
+			0, 0, 0, 0, 0, 0, 0, 10, // First Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Max Timestamp
+			0, 0, 0, 0, 0, 0, 0, 0, // Producer ID
+			0, 0, // Producer Epoch
+			0, 0, 0, 0, // First Sequence
+			0, 0, 0, 1, // Number of Records
+			4, 34, 77, 24, 100, 112, 185, 21, 0, 0, 128, 40, 0, 10, 0, 8, 1, 2, 3, 4, 6, 5, 6, 7, 2,
+			6, 8, 9, 10, 4, 11, 12, 0, 0, 0, 0, 12, 59, 239, 146,
+		},
+	},
+}
+
+func isOldGo(t *testing.T) bool {
+	v := strings.Split(runtime.Version()[2:], ".")
+	if len(v) < 2 {
+		t.Logf("Can't parse version: %s", runtime.Version())
+		return false
+	}
+	maj, err := strconv.Atoi(v[0])
+	if err != nil {
+		t.Logf("Can't parse version: %s", runtime.Version())
+		return false
+	}
+	min, err := strconv.Atoi(v[1])
+	if err != nil {
+		t.Logf("Can't parse version: %s", runtime.Version())
+		return false
+	}
+	return maj < 1 || (maj == 1 && min < 8)
+}
+
+func TestRecordBatchEncoding(t *testing.T) {
+	for _, tc := range recordBatchTestCases {
+		if tc.oldGoEncoded != nil && isOldGo(t) {
+			testEncodable(t, tc.name, &tc.batch, tc.oldGoEncoded)
+		} else {
+			testEncodable(t, tc.name, &tc.batch, tc.encoded)
+		}
+	}
+}
+
+func TestRecordBatchDecoding(t *testing.T) {
+	for _, tc := range recordBatchTestCases {
+		batch := RecordBatch{}
+		testDecodable(t, tc.name, &batch, tc.encoded)
+		for _, r := range batch.Records {
+			if _, err := r.getTotalLength(); err != nil {
+				t.Fatalf("Unexpected error: %v", err)
+			}
+		}
+		for _, r := range tc.batch.Records {
+			if _, err := r.getTotalLength(); err != nil {
+				t.Fatalf("Unexpected error: %v", err)
+			}
+		}
+		if !reflect.DeepEqual(batch, tc.batch) {
+			t.Errorf(spew.Sprintf("invalid decode of %s\ngot %+v\nwanted %+v", tc.name, batch, tc.batch))
+		}
+	}
+}

From e1067e3e2d36f435cb5893e59080613220c2f14b Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Tue, 24 Oct 2017 14:55:03 +0100
Subject: [PATCH 04/17] Implement a sum type that can hold RecordBatch or
 MessageSet

Many request/response structures can contain either RecordBatches or
MessageSets depending on the version of Kafka the client is talking to.
This changeset implements a sum type that makes it more convenient to work
with these structures by abstracting away the type of the records.
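Call sites can then work against Records without caring which representation
is inside, e.g. (an illustrative sketch; useV2 stands for whatever protocol
version check the caller already performs):

    var recs Records
    if useV2 {
    	recs = newDefaultRecords(&RecordBatch{Version: 2})
    } else {
    	recs = newLegacyRecords(&MessageSet{})
    }
    n, err := recs.numRecords() // same call for either representation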
---
 records.go      |  96 +++++++++++++++++++++++++++++++++++++++++
 records_test.go | 105 ++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 201 insertions(+)
 create mode 100644 records.go
 create mode 100644 records_test.go

diff --git a/records.go b/records.go
new file mode 100644
index 000000000..2b7953a46
--- /dev/null
+++ b/records.go
@@ -0,0 +1,96 @@
+package sarama
+
+import "fmt"
+
+const (
+	legacyRecords = iota
+	defaultRecords
+)
+
+// Records implements a union type containing either a RecordBatch or a legacy MessageSet.
+type Records struct {
+	recordsType int
+	msgSet      *MessageSet
+	recordBatch *RecordBatch
+}
+
+func newLegacyRecords(msgSet *MessageSet) Records {
+	return Records{recordsType: legacyRecords, msgSet: msgSet}
+}
+
+func newDefaultRecords(batch *RecordBatch) Records {
+	return Records{recordsType: defaultRecords, recordBatch: batch}
+}
+
+func (r *Records) encode(pe packetEncoder) error {
+	switch r.recordsType {
+	case legacyRecords:
+		if r.msgSet == nil {
+			return nil
+		}
+		return r.msgSet.encode(pe)
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return nil
+		}
+		return r.recordBatch.encode(pe)
+	}
+	return fmt.Errorf("unknown records type: %v", r.recordsType)
+}
+
+func (r *Records) decode(pd packetDecoder) error {
+	switch r.recordsType {
+	case legacyRecords:
+		r.msgSet = &MessageSet{}
+		return r.msgSet.decode(pd)
+	case defaultRecords:
+		r.recordBatch = &RecordBatch{}
+		return r.recordBatch.decode(pd)
+	}
+	return fmt.Errorf("unknown records type: %v", r.recordsType)
+}
+
+func (r *Records) numRecords() (int, error) {
+	switch r.recordsType {
+	case legacyRecords:
+		if r.msgSet == nil {
+			return 0, nil
+		}
+		return len(r.msgSet.Messages), nil
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return 0, nil
+		}
+		return len(r.recordBatch.Records), nil
+	}
+	return 0, fmt.Errorf("unknown records type: %v", r.recordsType)
+}
+
+func (r *Records) isPartial() (bool, error) {
+	switch r.recordsType {
+	case legacyRecords:
+		if r.msgSet == nil {
+			return false, nil
+		}
+		return r.msgSet.PartialTrailingMessage, nil
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return false, nil
+		}
+		return r.recordBatch.PartialTrailingRecord, nil
+	}
+	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
+}
+
+func (r *Records) isControl() (bool, error) {
+	switch r.recordsType {
+	case legacyRecords:
+		return false, nil
+	case defaultRecords:
+		if r.recordBatch == nil {
+			return false, nil
+		}
+		return r.recordBatch.Control, nil
+	}
+	return false, fmt.Errorf("unknown records type: %v", r.recordsType)
+}
diff --git a/records_test.go b/records_test.go
new file mode 100644
index 000000000..1b9f04708
--- /dev/null
+++ b/records_test.go
@@ -0,0 +1,105 @@
+package sarama
+
+import (
+	"bytes"
+	"reflect"
+	"testing"
+)
+
+func TestLegacyRecords(t *testing.T) {
+	set := &MessageSet{
+		Messages: []*MessageBlock{
+			{
+				Msg: &Message{
+					Version: 1,
+				},
+			},
+		},
+	}
+	r := newLegacyRecords(set)
+
+	exp, err := encode(set, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	buf, err := encode(&r, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(buf, exp) {
+		t.Errorf("Wrong encoding for legacy records, wanted %v, got %v", exp, buf)
+	}
+
+	set = &MessageSet{}
+	r = newLegacyRecords(nil)
+
+	err = decode(exp, set)
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = decode(buf, &r)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !reflect.DeepEqual(set, r.msgSet) {
+		t.Errorf("Wrong decoding for legacy records, wanted %#+v, got %#+v", set, r.msgSet)
+	}
+
+	n, err := r.numRecords()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if n != 1 {
+		t.Errorf("Wrong number of records, wanted 1, got %d", n)
+	}
+}
+
+func TestDefaultRecords(t *testing.T) {
+	batch := &RecordBatch{
+		Version: 2,
+		Records: []*Record{
+			{
+				Value: []byte{1},
+			},
+		},
+	}
+
+	r := newDefaultRecords(batch)
+
+	exp, err := encode(batch, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	buf, err := encode(&r, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(buf, exp) {
+		t.Errorf("Wrong encoding for default records, wanted %v, got %v", exp, buf)
+	}
+
+	batch = &RecordBatch{}
+	r = newDefaultRecords(nil)
+
+	err = decode(exp, batch)
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = decode(buf, &r)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if !reflect.DeepEqual(batch, r.recordBatch) {
+		t.Errorf("Wrong decoding for default records, wanted %#+v, got %#+v", batch, r.recordBatch)
+	}
+
+	n, err := r.numRecords()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if n != 1 {
+		t.Errorf("Wrong number of records, wanted 1, got %d", n)
+	}
+}

From 61fb33edde7db03b8ecf6ea476fcc2eb9a81f15d Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Thu, 26 Oct 2017 16:55:11 +0100
Subject: [PATCH 05/17] Rename Header to RecordHeader

---
 record.go      | 12 ++++++------
 record_test.go |  8 ++++----
 2 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/record.go b/record.go
index 427aef1c3..a9c4a5d97 100644
--- a/record.go
+++ b/record.go
@@ -6,19 +6,19 @@ const (
 	controlMask = 0x20
 )
 
-type Header struct {
+type RecordHeader struct {
 	Key   []byte
 	Value []byte
 }
 
-func (h *Header) encode(pe packetEncoder) error {
+func (h *RecordHeader) encode(pe packetEncoder) error {
 	if err := pe.putVarintBytes(h.Key); err != nil {
 		return err
 	}
 	return pe.putVarintBytes(h.Value)
 }
 
-func (h *Header) decode(pd packetDecoder) (err error) {
+func (h *RecordHeader) decode(pd packetDecoder) (err error) {
 	if h.Key, err = pd.getVarintBytes(); err != nil {
 		return err
 	}
@@ -35,7 +35,7 @@ type Record struct {
 	OffsetDelta    int64
 	Key            []byte
 	Value          []byte
-	Headers        []*Header
+	Headers        []*RecordHeader
 
 	lengthComputed bool
 	length         int64
@@ -105,10 +105,10 @@ func (r *Record) decode(pd packetDecoder) (err error) {
 	}
 
 	if numHeaders >= 0 {
-		r.Headers = make([]*Header, numHeaders)
+		r.Headers = make([]*RecordHeader, numHeaders)
 	}
 	for i := int64(0); i < numHeaders; i++ {
-		hdr := new(Header)
+		hdr := new(RecordHeader)
 		if err := hdr.decode(pd); err != nil {
 			return err
 		}
diff --git a/record_test.go b/record_test.go
index b1258f37f..66ab78cde 100644
--- a/record_test.go
+++ b/record_test.go
@@ -63,7 +63,7 @@ var recordBatchTestCases = []struct {
 				TimestampDelta: 5,
 				Key:            []byte{1, 2, 3, 4},
 				Value:          []byte{5, 6, 7},
-				Headers: []*Header{{
+				Headers: []*RecordHeader{{
 					Key:   []byte{8, 9, 10},
 					Value: []byte{11, 12},
 				}},
@@ -108,7 +108,7 @@ var recordBatchTestCases = []struct {
 				TimestampDelta: 5,
 				Key:            []byte{1, 2, 3, 4},
 				Value:          []byte{5, 6, 7},
-				Headers: []*Header{{
+				Headers: []*RecordHeader{{
 					Key:   []byte{8, 9, 10},
 					Value: []byte{11, 12},
 				}},
@@ -159,7 +159,7 @@ var recordBatchTestCases = []struct {
 				TimestampDelta: 5,
 				Key:            []byte{1, 2, 3, 4},
 				Value:          []byte{5, 6, 7},
-				Headers: []*Header{{
+				Headers: []*RecordHeader{{
 					Key:   []byte{8, 9, 10},
 					Value: []byte{11, 12},
 				}},
@@ -192,7 +192,7 @@ var recordBatchTestCases = []struct {
 				TimestampDelta: 5,
 				Key:            []byte{1, 2, 3, 4},
 				Value:          []byte{5, 6, 7},
-				Headers: []*Header{{
+				Headers: []*RecordHeader{{
 					Key:   []byte{8, 9, 10},
 					Value: []byte{11, 12},
 				}},

From 6d52b9931e531dc4068a0baf0afa286ab1faf1de Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Thu, 26 Oct 2017 16:55:28 +0100
Subject: [PATCH 06/17] Add test coverage for Records.isControl() and
 Records.isPartial()

---
 records_test.go | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+)

diff --git a/records_test.go b/records_test.go
index 1b9f04708..ff3e64412 100644
--- a/records_test.go
+++ b/records_test.go
@@ -53,6 +53,22 @@ func TestLegacyRecords(t *testing.T) {
 	if n != 1 {
 		t.Errorf("Wrong number of records, wanted 1, got %d", n)
 	}
+
+	p, err := r.isPartial()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if p {
+		t.Errorf("MessageSet shouldn't have a partial trailing message")
+	}
+
+	c, err := r.isControl()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if c {
+		t.Errorf("MessageSet can't be a control batch")
+	}
 }
 
 func TestDefaultRecords(t *testing.T) {
@@ -102,4 +118,20 @@ func TestDefaultRecords(t *testing.T) {
 	if n != 1 {
 		t.Errorf("Wrong number of records, wanted 1, got %d", n)
 	}
+
+	p, err := r.isPartial()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if p {
+		t.Errorf("RecordBatch shouldn't have a partial trailing record")
+	}
+
+	c, err := r.isControl()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if c {
+		t.Errorf("RecordBatch shouldn't be a control batch")
+	}
 }

From b37b1580be487185e0f6426568c079097ee0ed3e Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Mon, 30 Oct 2017 14:21:10 +0000
Subject: [PATCH 07/17] Introduce dynamicPushEncoders

Added a dynamicPushEncoder interface that extends pushEncoder with an
adjustLength method that is called at prepEncoder.pop() time so that it can
compute the actual length of the field.

Also made varintLengthField implement the new interface so we can avoid a
needless run of prepEncoder for uncompressed records.
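The delta returned by varintLengthField.adjustLength is simply the size of
the varint encoding of the final length, since the prep pass reserved zero
bytes for it. A standalone illustration (not part of the patch):

    package main

    import (
    	"encoding/binary"
    	"fmt"
    )

    func main() {
    	var tmp [binary.MaxVarintLen64]byte
    	// A 130-byte record body needs a 2-byte zig-zag varint, so pop()
    	// grows the computed length by 2.
    	fmt.Println(binary.PutVarint(tmp[:], 130)) // 2
    }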
---
 length_field.go   | 35 ++++++++++++++++++++++++-------
 packet_encoder.go | 10 ++++++++++
 prep_encoder.go   |  9 +++++++++
 record.go         | 50 ++++++++------------------------------
 record_batch.go   | 32 +++++++++++++++---------------
 record_test.go    |  8 ++------
 6 files changed, 75 insertions(+), 69 deletions(-)

diff --git a/length_field.go b/length_field.go
index 279791b18..1f49fed58 100644
--- a/length_field.go
+++ b/length_field.go
@@ -31,22 +31,43 @@ func (l *lengthField) check(curOffset int, buf []byte) error {
 type varintLengthField struct {
 	startOffset int
 	length      int64
+	adjusted    bool
+	size        int
 }
 
-func newVarintLengthField(pd packetDecoder) (*varintLengthField, error) {
-	n, err := pd.getVarint()
-	if err != nil {
-		return nil, err
-	}
-	return &varintLengthField{length: n}, nil
+func (l *varintLengthField) decode(pd packetDecoder) error {
+	var err error
+	l.length, err = pd.getVarint()
+	return err
 }
 
 func (l *varintLengthField) saveOffset(in int) {
 	l.startOffset = in
 }
 
+func (l *varintLengthField) adjustLength(currOffset int) int {
+	l.adjusted = true
+
+	var tmp [binary.MaxVarintLen64]byte
+	l.length = int64(currOffset - l.startOffset - l.size)
+
+	newSize := binary.PutVarint(tmp[:], l.length)
+	diff := newSize - l.size
+	l.size = newSize
+
+	return diff
+}
+
 func (l *varintLengthField) reserveLength() int {
-	return 0
+	return l.size
+}
+
+func (l *varintLengthField) run(curOffset int, buf []byte) error {
+	if !l.adjusted {
+		return PacketEncodingError{"varintLengthField.run called before adjustLength"}
+	}
+	binary.PutVarint(buf[l.startOffset:], l.length)
+	return nil
 }
 
 func (l *varintLengthField) check(curOffset int, buf []byte) error {
diff --git a/packet_encoder.go b/packet_encoder.go
index 8b9a61675..0d3cfc857 100644
--- a/packet_encoder.go
+++ b/packet_encoder.go
@@ -50,3 +50,13 @@ type pushEncoder interface {
 	// of data to the saved offset, based on the data between the saved offset and curOffset.
 	run(curOffset int, buf []byte) error
 }
+
+// dynamicPushEncoder extends the interface of pushEncoder for use cases where the length of the
+// field itself is unknown until its value was computed (for instance varint encoded lenght
+// fields).
+type dynamicPushEncoder interface {
+	pushEncoder
+
+	// Called during pop() to adjust the length of the field.
+	adjustLength(currOffset int) int
+}
diff --git a/prep_encoder.go b/prep_encoder.go
index 97b0b81d2..3c890fcdc 100644
--- a/prep_encoder.go
+++ b/prep_encoder.go
@@ -9,6 +9,7 @@ import (
 )
 
 type prepEncoder struct {
+	stack  []pushEncoder
 	length int
 }
 
@@ -119,10 +120,18 @@ func (pe *prepEncoder) offset() int {
 // stackable
 
 func (pe *prepEncoder) push(in pushEncoder) {
+	in.saveOffset(pe.length)
 	pe.length += in.reserveLength()
+	pe.stack = append(pe.stack, in)
 }
 
 func (pe *prepEncoder) pop() error {
+	in := pe.stack[len(pe.stack)-1]
+	pe.stack = pe.stack[:len(pe.stack)-1]
+	if dpe, ok := in.(dynamicPushEncoder); ok {
+		pe.length += dpe.adjustLength(pe.length)
+	}
+
 	return nil
 }
diff --git a/record.go b/record.go
index a9c4a5d97..9d5d32153 100644
--- a/record.go
+++ b/record.go
@@ -1,7 +1,5 @@
 package sarama
 
-import "encoding/binary"
-
 const (
 	controlMask = 0x20
 )
@@ -37,17 +35,12 @@ type Record struct {
 	Value          []byte
 	Headers        []*RecordHeader
 
-	lengthComputed bool
-	length         int64
-	totalLength    int
+	length      varintLengthField
+	totalLength int
 }
 
 func (r *Record) encode(pe packetEncoder) error {
-	if err := r.computeLength(); err != nil {
-		return err
-	}
-
-	pe.putVarint(r.length)
+	pe.push(&r.length)
 	pe.putInt8(r.Attributes)
 	pe.putVarint(r.TimestampDelta)
 	pe.putVarint(r.OffsetDelta)
@@ -65,19 +58,16 @@ func (r *Record) encode(pe packetEncoder) error {
 		}
 	}
 
-	return nil
+	return pe.pop()
 }
 
 func (r *Record) decode(pd packetDecoder) (err error) {
-	length, err := newVarintLengthField(pd)
-	if err != nil {
+	if err := r.length.decode(pd); err != nil {
 		return err
 	}
-	if err = pd.push(length); err != nil {
+	if err = pd.push(&r.length); err != nil {
 		return err
 	}
-	r.length = length.length
-	r.lengthComputed = true
 
 	if r.Attributes, err = pd.getInt8(); err != nil {
 		return err
@@ -118,32 +108,12 @@ func (r *Record) decode(pd packetDecoder) (err error) {
 	return pd.pop()
 }
 
-// Because the length is varint we can't reserve a fixed amount of bytes for it.
-// We use the prepEncoder to figure out the length of the record and then we cache it.
-func (r *Record) computeLength() error {
-	if !r.lengthComputed {
-		r.lengthComputed = true
-
-		var prep prepEncoder
-		if err := r.encode(&prep); err != nil {
-			return err
-		}
-		// subtract 1 because we don't want to include the length field itself (which is 1 byte, the
-		// length of the varint encoding of 0)
-		r.length = int64(prep.length) - 1
-	}
-
-	return nil
-}
-
 func (r *Record) getTotalLength() (int, error) {
-	if r.totalLength == 0 {
-		if err := r.computeLength(); err != nil {
+	var prep prepEncoder
+	if !r.length.adjusted {
+		if err := r.encode(&prep); err != nil {
 			return 0, err
 		}
-		var buf [binary.MaxVarintLen64]byte
-		r.totalLength = int(r.length) + binary.PutVarint(buf[:], r.length)
 	}
-
-	return r.totalLength, nil
+	return int(r.length.length) + r.length.size, nil
 }
diff --git a/record_batch.go b/record_batch.go
index 83d6f25c7..c24b500f9 100644
--- a/record_batch.go
+++ b/record_batch.go
@@ -59,10 +59,7 @@ func (b *RecordBatch) encode(pe packetEncoder) error {
 		if err := pe.pop(); err != nil {
 			return err
 		}
-		if err := pe.pop(); err != nil {
-			return err
-		}
-		return nil
+		return pe.pop()
 	}
 
 	var re packetEncoder
@@ -72,14 +69,9 @@ func (b *RecordBatch) encode(pe packetEncoder) error {
 	case CompressionNone:
 		re = pe
 	case CompressionGZIP, CompressionLZ4, CompressionSnappy:
-		for _, r := range b.Records {
-			l, err := r.getTotalLength()
-			if err != nil {
-				return err
-			}
-			b.recordsLen += l
+		if err := b.computeRecordsLength(); err != nil {
+			return err
 		}
-
-		raw = make([]byte, b.recordsLen)
+		raw = make([]byte, b.recordsLen)
 		re = &realEncoder{raw: raw}
 	default:
@@ -249,6 +237,18 @@ func (b *RecordBatch) computeAttributes() int16 {
 	return attr
 }
 
+func (b *RecordBatch) computeRecordsLength() error {
+	b.recordsLen = 0
+	for _, r := range b.Records {
+		l, err := r.getTotalLength()
+		if err != nil {
+			return err
+		}
+		b.recordsLen += l
+	}
+	return nil
+}
+
 func (b *RecordBatch) addRecord(r *Record) {
 	b.Records = append(b.Records, r)
 }
diff --git a/record_test.go b/record_test.go
index 66ab78cde..5393383a1 100644
--- a/record_test.go
+++ b/record_test.go
@@ -252,14 +252,10 @@ func TestRecordBatchDecoding(t *testing.T) {
 		batch := RecordBatch{}
 		testDecodable(t, tc.name, &batch, tc.encoded)
 		for _, r := range batch.Records {
-			if _, err := r.getTotalLength(); err != nil {
-				t.Fatalf("Unexpected error: %v", err)
-			}
+			r.length = varintLengthField{}
 		}
 		for _, r := range tc.batch.Records {
-			if _, err := r.getTotalLength(); err != nil {
-				t.Fatalf("Unexpected error: %v", err)
-			}
+			r.length = varintLengthField{}
 		}
 		if !reflect.DeepEqual(batch, tc.batch) {
 			t.Errorf(spew.Sprintf("invalid decode of %s\ngot %+v\nwanted %+v", tc.name, batch, tc.batch))
 		}
 	}

From 371165d592ebf0920198be616f56fc7ddf986a93 Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Tue, 31 Oct 2017 11:50:17 +0000
Subject: [PATCH 08/17] Clarify the expected return value of adjustLength

---
 packet_encoder.go | 1 +
 1 file changed, 1 insertion(+)

diff --git a/packet_encoder.go b/packet_encoder.go
index 0d3cfc857..98ccdf023 100644
--- a/packet_encoder.go
+++ b/packet_encoder.go
@@ -58,5 +58,6 @@ type dynamicPushEncoder interface {
 	pushEncoder
 
 	// Called during pop() to adjust the length of the field.
+	// It should return the difference in bytes between the last computed length and current length.
 	adjustLength(currOffset int) int
 }
From 135aca9ac6b38c19cde1c6efc75bbf1c8017b4ee Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Tue, 31 Oct 2017 11:51:37 +0000
Subject: [PATCH 09/17] Don't cache the length

---
 record.go | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/record.go b/record.go
index 9d5d32153..8cfd58909 100644
--- a/record.go
+++ b/record.go
@@ -35,8 +35,7 @@ type Record struct {
 	Value          []byte
 	Headers        []*RecordHeader
 
-	length      varintLengthField
-	totalLength int
+	length varintLengthField
 }
 
 func (r *Record) encode(pe packetEncoder) error {
@@ -110,10 +109,6 @@ func (r *Record) decode(pd packetDecoder) (err error) {
 	return pd.pop()
 }
 
 func (r *Record) getTotalLength() (int, error) {
 	var prep prepEncoder
-	if !r.length.adjusted {
-		if err := r.encode(&prep); err != nil {
-			return 0, err
-		}
-	}
-	return int(r.length.length) + r.length.size, nil
+	err := r.encode(&prep)
+	return prep.length, err
 }
From b51e2317f406d669fed60e6746006376b9263483 Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Tue, 31 Oct 2017 11:51:57 +0000
Subject: [PATCH 10/17] Make an encoder/decoder for records array

---
 record_batch.go | 67 ++++++++++++++++++++++++++-----------------------
 1 file changed, 36 insertions(+), 31 deletions(-)

diff --git a/record_batch.go b/record_batch.go
index c24b500f9..f5fa0ad2a 100644
--- a/record_batch.go
+++ b/record_batch.go
@@ -12,6 +12,28 @@ import (
 
 const recordBatchOverhead = 49
 
+type recordsArray []*Record
+
+func (e recordsArray) encode(pe packetEncoder) error {
+	for _, r := range e {
+		if err := r.encode(pe); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (e recordsArray) decode(pd packetDecoder) error {
+	for i := range e {
+		rec := &Record{}
+		if err := rec.decode(pd); err != nil {
+			return err
+		}
+		e[i] = rec
+	}
+	return nil
+}
+
 type RecordBatch struct {
 	FirstOffset           int64
 	PartitionLeaderEpoch  int32
@@ -62,29 +84,18 @@ func (b *RecordBatch) encode(pe packetEncoder) error {
 		return pe.pop()
 	}
 
-	var re packetEncoder
 	var raw []byte
-
-	switch b.Codec {
-	case CompressionNone:
-		re = pe
-	case CompressionGZIP, CompressionLZ4, CompressionSnappy:
-		if err := b.computeRecordsLength(); err != nil {
+	if b.Codec != CompressionNone {
+		var err error
+		if raw, err = encode(recordsArray(b.Records), nil); err != nil {
 			return err
 		}
-		raw = make([]byte, b.recordsLen)
-		re = &realEncoder{raw: raw}
-	default:
-		return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
 	}
-
-	for _, r := range b.Records {
-		if err := r.encode(re); err != nil {
+	switch b.Codec {
+	case CompressionNone:
+		if err := recordsArray(b.Records).encode(pe); err != nil {
 			return err
 		}
-	}
-
-	switch b.Codec {
 	case CompressionGZIP:
 		var buf bytes.Buffer
 		writer := gzip.NewWriter(&buf)
@@ -118,6 +118,8 @@ func (b *RecordBatch) encode(pe packetEncoder) error {
 			return err
 		}
 		b.compressedRecords = buf.Bytes()
+	default:
+		return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
 	}
 	if err := pe.putRawBytes(b.compressedRecords); err != nil {
 		return err
@@ -211,22 +224,14 @@ func (b *RecordBatch) decode(pd packetDecoder) (err error) {
 	default:
 		return PacketDecodingError{fmt.Sprintf("invalid compression specified (%d)", b.Codec)}
 	}
-	recPd := &realDecoder{raw: recBuffer}
 
-	for i := 0; i < numRecs; i++ {
-		rec := &Record{}
-		if err = rec.decode(recPd); err != nil {
-			if err == ErrInsufficientData {
-				b.PartialTrailingRecord = true
-				b.Records = nil
-				return nil
-			}
-			return err
-		}
-		b.Records[i] = rec
+	err = decode(recBuffer, recordsArray(b.Records))
+	if err == ErrInsufficientData {
+		b.PartialTrailingRecord = true
+		b.Records = nil
+		return nil
 	}
-
-	return nil
+	return err
 }
From 5ffc7bf531ddacf70438313c38c2815b462a07a2 Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Tue, 31 Oct 2017 13:02:57 +0000
Subject: [PATCH 11/17] Simplify uncompressed records length computation

---
 record.go       |  6 ------
 record_batch.go | 17 ++++-------------
 2 files changed, 4 insertions(+), 19 deletions(-)

diff --git a/record.go b/record.go
index 8cfd58909..df4090aee 100644
--- a/record.go
+++ b/record.go
@@ -106,9 +106,3 @@ func (r *Record) decode(pd packetDecoder) (err error) {
 
 	return pd.pop()
 }
-
-func (r *Record) getTotalLength() (int, error) {
-	var prep prepEncoder
-	err := r.encode(&prep)
-	return prep.length, err
-}
diff --git a/record_batch.go b/record_batch.go
index f5fa0ad2a..1439481f3 100644
--- a/record_batch.go
+++ b/record_batch.go
@@ -50,7 +50,7 @@ type RecordBatch struct {
 	PartialTrailingRecord bool
 
 	compressedRecords []byte
-	recordsLen        int
+	recordsLen        int // uncompressed records size
 }
 
 func (b *RecordBatch) encode(pe packetEncoder) error {
@@ -90,12 +90,15 @@ func (b *RecordBatch) encode(pe packetEncoder) error {
 		if raw, err = encode(recordsArray(b.Records), nil); err != nil {
 			return err
 		}
+		b.recordsLen = len(raw)
 	}
 	switch b.Codec {
 	case CompressionNone:
+		offset := pe.offset()
 		if err := recordsArray(b.Records).encode(pe); err != nil {
 			return err
 		}
+		b.recordsLen = pe.offset() - offset
 	case CompressionGZIP:
@@ -242,18 +245,6 @@ func (b *RecordBatch) computeAttributes() int16 {
 	return attr
 }
 
-func (b *RecordBatch) computeRecordsLength() error {
-	b.recordsLen = 0
-	for _, r := range b.Records {
-		l, err := r.getTotalLength()
-		if err != nil {
-			return err
-		}
-		b.recordsLen += l
-	}
-	return nil
-}
-
 func (b *RecordBatch) addRecord(r *Record) {
 	b.Records = append(b.Records, r)
 }
From 4a9a2bcc4756c7143c3ef53e330a0ad229c17ff3 Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Tue, 31 Oct 2017 14:43:59 +0000
Subject: [PATCH 12/17] Get rid of size and adjusted fields in
 varintLengthField

---
 length_field.go | 20 +++++---------------
 1 file changed, 5 insertions(+), 15 deletions(-)

diff --git a/length_field.go b/length_field.go
index 1f49fed58..662550bfd 100644
--- a/length_field.go
+++ b/length_field.go
@@ -31,8 +31,6 @@ func (l *lengthField) check(curOffset int, buf []byte) error {
 type varintLengthField struct {
 	startOffset int
 	length      int64
-	adjusted    bool
-	size        int
 }
 
 func (l *varintLengthField) decode(pd packetDecoder) error {
@@ -46,26 +44,18 @@ func (l *varintLengthField) saveOffset(in int) {
 }
 
 func (l *varintLengthField) adjustLength(currOffset int) int {
-	l.adjusted = true
+	oldFieldSize := l.reserveLength()
+	l.length = int64(currOffset - l.startOffset - oldFieldSize)
 
-	var tmp [binary.MaxVarintLen64]byte
-	l.length = int64(currOffset - l.startOffset - l.size)
-
-	newSize := binary.PutVarint(tmp[:], l.length)
-	diff := newSize - l.size
-	l.size = newSize
-
-	return diff
+	return l.reserveLength() - oldFieldSize
 }
 
 func (l *varintLengthField) reserveLength() int {
-	return l.size
+	var tmp [binary.MaxVarintLen64]byte
+	return binary.PutVarint(tmp[:], l.length)
 }
 
 func (l *varintLengthField) run(curOffset int, buf []byte) error {
-	if !l.adjusted {
-		return PacketEncodingError{"varintLengthField.run called before adjustLength"}
-	}
 	binary.PutVarint(buf[l.startOffset:], l.length)
 	return nil
 }

From c302872c8f1fa6d65cf93eb5cf526890ae6cbae8 Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Tue, 31 Oct 2017 14:44:52 +0000
Subject: [PATCH 13/17] Fix typo

---
 packet_encoder.go | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/packet_encoder.go b/packet_encoder.go
index 98ccdf023..b356ab8e8 100644
--- a/packet_encoder.go
+++ b/packet_encoder.go
@@ -52,7 +52,7 @@ type pushEncoder interface {
 }
 
 // dynamicPushEncoder extends the interface of pushEncoder for use cases where the length of the
-// field itself is unknown until its value was computed (for instance varint encoded lenght
+// field itself is unknown until its value was computed (for instance varint encoded length
 // fields).
 type dynamicPushEncoder interface {
 	pushEncoder
From bd304f1b4003bc446abdbf0efa128a6347211a3f Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Tue, 31 Oct 2017 14:45:09 +0000
Subject: [PATCH 14/17] Refactor record batch encoding

---
 record_batch.go | 105 ++++++++++++++++++++++++------------------------
 1 file changed, 52 insertions(+), 53 deletions(-)

diff --git a/record_batch.go b/record_batch.go
index 1439481f3..80963ee16 100644
--- a/record_batch.go
+++ b/record_batch.go
@@ -74,55 +74,8 @@ func (b *RecordBatch) encode(pe packetEncoder) error {
 		return err
 	}
 
-	if b.compressedRecords != nil {
-		if err := pe.putRawBytes(b.compressedRecords); err != nil {
-			return err
-		}
-		if err := pe.pop(); err != nil {
-			return err
-		}
-		return pe.pop()
-	}
-
-	var raw []byte
-	if b.Codec != CompressionNone {
-		var err error
-		if raw, err = encode(recordsArray(b.Records), nil); err != nil {
-			return err
-		}
-		b.recordsLen = len(raw)
-	}
-	switch b.Codec {
-	case CompressionNone:
-		offset := pe.offset()
-		if err := recordsArray(b.Records).encode(pe); err != nil {
-			return err
-		}
-		b.recordsLen = pe.offset() - offset
-	case CompressionGZIP:
-		var buf bytes.Buffer
-		writer := gzip.NewWriter(&buf)
-		if _, err := writer.Write(raw); err != nil {
-			return err
-		}
-		if err := writer.Close(); err != nil {
-			return err
-		}
-		b.compressedRecords = buf.Bytes()
-	case CompressionSnappy:
-		b.compressedRecords = snappy.Encode(raw)
-	case CompressionLZ4:
-		var buf bytes.Buffer
-		writer := lz4.NewWriter(&buf)
-		if _, err := writer.Write(raw); err != nil {
-			return err
-		}
-		if err := writer.Close(); err != nil {
-			return err
-		}
-		b.compressedRecords = buf.Bytes()
-	default:
-		return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
+	if b.compressedRecords == nil {
+		b.encodeRecords(pe)
 	}
 	if err := pe.putRawBytes(b.compressedRecords); err != nil {
 		return err
@@ -139,8 +92,8 @@ func (b *RecordBatch) decode(pd packetDecoder) (err error) {
 		return err
 	}
 
-	var batchLen int32
-	if batchLen, err = pd.getInt32(); err != nil {
+	batchLen, err := pd.getInt32()
+	if err != nil {
 		return err
 	}
 
@@ -156,8 +109,8 @@ func (b *RecordBatch) decode(pd packetDecoder) (err error) {
 		return err
 	}
 
-	var attributes int16
-	if attributes, err = pd.getInt16(); err != nil {
+	attributes, err := pd.getInt16()
+	if err != nil {
 		return err
 	}
 	b.Codec = CompressionCodec(int8(attributes) & compressionCodecMask)
@@ -237,6 +190,52 @@ func (b *RecordBatch) decode(pd packetDecoder) (err error) {
 	return err
 }
 
+func (b *RecordBatch) encodeRecords(pe packetEncoder) error {
+	var raw []byte
+	if b.Codec != CompressionNone {
+		var err error
+		if raw, err = encode(recordsArray(b.Records), nil); err != nil {
+			return err
+		}
+		b.recordsLen = len(raw)
+	}
+
+	switch b.Codec {
+	case CompressionNone:
+		offset := pe.offset()
+		if err := recordsArray(b.Records).encode(pe); err != nil {
+			return err
+		}
+		b.recordsLen = pe.offset() - offset
+	case CompressionGZIP:
+		var buf bytes.Buffer
+		writer := gzip.NewWriter(&buf)
+		if _, err := writer.Write(raw); err != nil {
+			return err
+		}
+		if err := writer.Close(); err != nil {
+			return err
+		}
+		b.compressedRecords = buf.Bytes()
+	case CompressionSnappy:
+		b.compressedRecords = snappy.Encode(raw)
+	case CompressionLZ4:
+		var buf bytes.Buffer
+		writer := lz4.NewWriter(&buf)
+		if _, err := writer.Write(raw); err != nil {
+			return err
+		}
+		if err := writer.Close(); err != nil {
+			return err
+		}
+		b.compressedRecords = buf.Bytes()
+	default:
+		return PacketEncodingError{fmt.Sprintf("unsupported compression codec (%d)", b.Codec)}
+	}
+
+	return nil
+}
+
 func (b *RecordBatch) computeAttributes() int16 {
 	attr := int16(b.Codec) & int16(compressionCodecMask)
 	if b.Control {

From 033fed9b9208f037a16d2ee5dca2732f9894c674 Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Tue, 31 Oct 2017 15:00:13 +0000
Subject: [PATCH 15/17] varintLengthField.reserveLength() should return 0
 during decode
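During decode the length varint has already been consumed by
varintLengthField.decode() before push() is called, so a non-zero
reservation would make realDecoder.push() require those bytes to still be
available in the buffer, which can spuriously report ErrInsufficientData at
the tail of an otherwise complete buffer. The consumed flag makes the field
report a zero-byte reservation for that one push.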
---
 length_field.go | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/length_field.go b/length_field.go
index 662550bfd..89731d643 100644
--- a/length_field.go
+++ b/length_field.go
@@ -31,10 +31,12 @@ func (l *lengthField) check(curOffset int, buf []byte) error {
 type varintLengthField struct {
 	startOffset int
 	length      int64
+	consumed    bool
 }
 
 func (l *varintLengthField) decode(pd packetDecoder) error {
 	var err error
+	l.consumed = true
 	l.length, err = pd.getVarint()
 	return err
 }
@@ -51,6 +53,10 @@ func (l *varintLengthField) adjustLength(currOffset int) int {
 }
 
 func (l *varintLengthField) reserveLength() int {
+	if l.consumed { // the field was consumed during the decode
+		l.consumed = false
+		return 0
+	}
 	var tmp [binary.MaxVarintLen64]byte
 	return binary.PutVarint(tmp[:], l.length)
 }

From 808ea149d84fbfb14227982b41c124c03ed8d6e6 Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Tue, 31 Oct 2017 15:01:06 +0000
Subject: [PATCH 16/17] Check records encoding error

---
 record_batch.go | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/record_batch.go b/record_batch.go
index 80963ee16..9379e4378 100644
--- a/record_batch.go
+++ b/record_batch.go
@@ -75,7 +75,9 @@ func (b *RecordBatch) encode(pe packetEncoder) error {
 	}
 
 	if b.compressedRecords == nil {
-		b.encodeRecords(pe)
+		if err := b.encodeRecords(pe); err != nil {
+			return err
+		}
 	}
 	if err := pe.putRawBytes(b.compressedRecords); err != nil {
 		return err

From ff1f79c54b5114ebf7d3e4164591933ee70c6569 Mon Sep 17 00:00:00 2001
From: Vlad Hanciuta
Date: Tue, 31 Oct 2017 15:25:43 +0000
Subject: [PATCH 17/17] Add dynamicPushDecoder interface

dynamicPushDecoder extends pushDecoder for cases when the field has
variable length.

Also, changed varintLengthField to make use of the new interface.
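With this interface, realDecoder.push() records saveOffset before the varint
is read and then lets the field decode itself, which is why
varintLengthField.check() now compares curOffset - startOffset -
reserveLength() against the decoded length. As illustrative arithmetic: for
a record whose length varint occupies 1 byte and whose body is 39 bytes,
check() verifies 40 - 1 == 39.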
---
 length_field.go   |  8 +-------
 packet_decoder.go |  9 +++++++++
 real_decoder.go   | 15 +++++++++++----
 record.go         |  3 ---
 4 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/length_field.go b/length_field.go
index 89731d643..576b1a6f6 100644
--- a/length_field.go
+++ b/length_field.go
@@ -31,12 +31,10 @@ func (l *lengthField) check(curOffset int, buf []byte) error {
 type varintLengthField struct {
 	startOffset int
 	length      int64
-	consumed    bool
 }
 
 func (l *varintLengthField) decode(pd packetDecoder) error {
 	var err error
-	l.consumed = true
 	l.length, err = pd.getVarint()
 	return err
 }
@@ -53,10 +51,6 @@ func (l *varintLengthField) adjustLength(currOffset int) int {
 }
 
 func (l *varintLengthField) reserveLength() int {
-	if l.consumed { // the field was consumed during the decode
-		l.consumed = false
-		return 0
-	}
 	var tmp [binary.MaxVarintLen64]byte
 	return binary.PutVarint(tmp[:], l.length)
 }
@@ -67,7 +61,7 @@ func (l *varintLengthField) run(curOffset int, buf []byte) error {
 }
 
 func (l *varintLengthField) check(curOffset int, buf []byte) error {
-	if int64(curOffset-l.startOffset) != l.length {
+	if int64(curOffset-l.startOffset-l.reserveLength()) != l.length {
 		return PacketDecodingError{"length field invalid"}
 	}
 
diff --git a/packet_decoder.go b/packet_decoder.go
index 5a3b461a5..387d12cb1 100644
--- a/packet_decoder.go
+++ b/packet_decoder.go
@@ -46,3 +46,12 @@ type pushDecoder interface {
 	// of data from the saved offset, and verify it based on the data between the saved offset and curOffset.
 	check(curOffset int, buf []byte) error
 }
+
+// dynamicPushDecoder extends the interface of pushDecoder for use cases where the length of the
+// field itself is unknown until its value is decoded (for instance varint encoded length
+// fields).
+// During push, the dynamicPushDecoder.decode() method will be called instead of reserveLength()
+type dynamicPushDecoder interface {
+	pushDecoder
+	decoder
+}
diff --git a/real_decoder.go b/real_decoder.go
index 05bcd207f..48ab8e86e 100644
--- a/real_decoder.go
+++ b/real_decoder.go
@@ -260,10 +260,17 @@ func (rd *realDecoder) getRawBytes(length int) ([]byte, error) {
 func (rd *realDecoder) push(in pushDecoder) error {
 	in.saveOffset(rd.off)
 
-	reserve := in.reserveLength()
-	if rd.remaining() < reserve {
-		rd.off = len(rd.raw)
-		return ErrInsufficientData
+	var reserve int
+	if dpd, ok := in.(dynamicPushDecoder); ok {
+		if err := dpd.decode(rd); err != nil {
+			return err
+		}
+	} else {
+		reserve = in.reserveLength()
+		if rd.remaining() < reserve {
+			rd.off = len(rd.raw)
+			return ErrInsufficientData
+		}
 	}
 
 	rd.stack = append(rd.stack, in)
diff --git a/record.go b/record.go
index df4090aee..4da3783a5 100644
--- a/record.go
+++ b/record.go
@@ -61,9 +61,6 @@ func (r *Record) encode(pe packetEncoder) error {
 }
 
 func (r *Record) decode(pd packetDecoder) (err error) {
-	if err := r.length.decode(pd); err != nil {
-		return err
-	}
 	if err = pd.push(&r.length); err != nil {
 		return err
 	}