From fe4198b648499375651b7fece0b8489ea07d029f Mon Sep 17 00:00:00 2001 From: twmb Date: Thu, 9 Jul 2015 16:11:03 -0700 Subject: [PATCH] enforce minimum and maximum msg size when reading from and writing to diskqueue Prevents panics in the face of corrupt messages. --- nsqd/channel.go | 2 ++ nsqd/diskqueue.go | 23 ++++++++++++++++++++--- nsqd/diskqueue_test.go | 41 ++++++++++++++++++++++++++--------------- nsqd/message.go | 7 +++++-- nsqd/topic.go | 2 ++ 5 files changed, 55 insertions(+), 20 deletions(-) diff --git a/nsqd/channel.go b/nsqd/channel.go index 6bcb58a7e..d6662685a 100644 --- a/nsqd/channel.go +++ b/nsqd/channel.go @@ -105,6 +105,8 @@ func NewChannel(topicName string, channelName string, ctx *context, c.backend = newDiskQueue(backendName, ctx.nsqd.opts.DataPath, ctx.nsqd.opts.MaxBytesPerFile, + int32(minValidMsgLength), + int32(ctx.nsqd.opts.MaxMsgSize), ctx.nsqd.opts.SyncEvery, ctx.nsqd.opts.SyncTimeout, ctx.nsqd.opts.Logger) diff --git a/nsqd/diskqueue.go b/nsqd/diskqueue.go index 75ec15dfd..d1109ce08 100644 --- a/nsqd/diskqueue.go +++ b/nsqd/diskqueue.go @@ -32,7 +32,9 @@ type diskQueue struct { // instantiation time metadata name string dataPath string - maxBytesPerFile int64 // currently this cannot change once created + maxBytesPerFile int64 // currently this cannot change once created + minMsgSize int32 + maxMsgSize int32 syncEvery int64 // number of writes per fsync syncTimeout time.Duration // duration of time per fsync exitFlag int32 @@ -65,12 +67,15 @@ type diskQueue struct { // newDiskQueue instantiates a new instance of diskQueue, retrieving metadata // from the filesystem and starting the read ahead goroutine func newDiskQueue(name string, dataPath string, maxBytesPerFile int64, + minMsgSize int32, maxMsgSize int32, syncEvery int64, syncTimeout time.Duration, logger logger) BackendQueue { d := diskQueue{ name: name, dataPath: dataPath, maxBytesPerFile: maxBytesPerFile, + minMsgSize: minMsgSize, + maxMsgSize: maxMsgSize, readChan: make(chan []byte), writeChan: make(chan []byte), writeResponseChan: make(chan error), @@ -261,6 +266,14 @@ func (d *diskQueue) readOne() ([]byte, error) { return nil, err } + if msgSize < d.minMsgSize || msgSize > d.maxMsgSize { + // this file is corrupt and we have no reasonable guarantee on + // where a new message should begin + d.readFile.Close() + d.readFile = nil + return nil, fmt.Errorf("invalid message read size (%d)", msgSize) + } + readBuf := make([]byte, msgSize) _, err = io.ReadFull(d.reader, readBuf) if err != nil { @@ -316,10 +329,14 @@ func (d *diskQueue) writeOne(data []byte) error { } } - dataLen := len(data) + dataLen := int32(len(data)) + + if dataLen < d.minMsgSize || dataLen > d.maxMsgSize { + return fmt.Errorf("invalid message write size (%d)", dataLen) + } d.writeBuf.Reset() - err = binary.Write(&d.writeBuf, binary.BigEndian, int32(dataLen)) + err = binary.Write(&d.writeBuf, binary.BigEndian, dataLen) if err != nil { return err } diff --git a/nsqd/diskqueue_test.go b/nsqd/diskqueue_test.go index bafec2118..83037385c 100644 --- a/nsqd/diskqueue_test.go +++ b/nsqd/diskqueue_test.go @@ -2,6 +2,7 @@ package nsqd import ( "bufio" + "bytes" "fmt" "io/ioutil" "os" @@ -22,7 +23,7 @@ func TestDiskQueue(t *testing.T) { panic(err) } defer os.RemoveAll(tmpDir) - dq := newDiskQueue(dqName, tmpDir, 1024, 2500, 2*time.Second, l) + dq := newDiskQueue(dqName, tmpDir, 1024, 4, 1<<10, 2500, 2*time.Second, l) nequal(t, dq, nil) equal(t, dq.Depth(), int64(0)) @@ -43,11 +44,12 @@ func TestDiskQueueRoll(t *testing.T) { panic(err) } defer os.RemoveAll(tmpDir) - dq := newDiskQueue(dqName, tmpDir, 100, 2500, 2*time.Second, l) + msg := bytes.Repeat([]byte{0}, 10) + ml := int64(len(msg)) + dq := newDiskQueue(dqName, tmpDir, 9*(ml+4), int32(ml), 1<<10, 2500, 2*time.Second, l) nequal(t, dq, nil) equal(t, dq.Depth(), int64(0)) - msg := []byte("aaaaaaaaaa") for i := 0; i < 10; i++ { err := dq.Put(msg) equal(t, err, nil) @@ -55,7 +57,7 @@ func TestDiskQueueRoll(t *testing.T) { } equal(t, dq.(*diskQueue).writeFileNum, int64(1)) - equal(t, dq.(*diskQueue).writePos, int64(28)) + equal(t, dq.(*diskQueue).writePos, int64(0)) } func assertFileNotExist(t *testing.T, fn string) { @@ -72,12 +74,11 @@ func TestDiskQueueEmpty(t *testing.T) { panic(err) } defer os.RemoveAll(tmpDir) - dq := newDiskQueue(dqName, tmpDir, 100, 2500, 2*time.Second, l) + msg := bytes.Repeat([]byte{0}, 10) + dq := newDiskQueue(dqName, tmpDir, 100, 0, 1<<10, 2500, 2*time.Second, l) nequal(t, dq, nil) equal(t, dq.Depth(), int64(0)) - msg := []byte("aaaaaaaaaa") - for i := 0; i < 100; i++ { err := dq.Put(msg) equal(t, err, nil) @@ -140,9 +141,10 @@ func TestDiskQueueCorruption(t *testing.T) { panic(err) } defer os.RemoveAll(tmpDir) - dq := newDiskQueue(dqName, tmpDir, 1000, 5, 2*time.Second, l) + // require a non-zero message length for the corrupt (len 0) test below + dq := newDiskQueue(dqName, tmpDir, 1000, 10, 1<<10, 5, 2*time.Second, l) - msg := make([]byte, 123) + msg := make([]byte, 123) // 127 bytes per message, 8 (1016 bytes) messages per file for i := 0; i < 25; i++ { dq.Put(msg) } @@ -151,9 +153,9 @@ func TestDiskQueueCorruption(t *testing.T) { // corrupt the 2nd file dqFn := dq.(*diskQueue).fileName(1) - os.Truncate(dqFn, 500) + os.Truncate(dqFn, 500) // 3 valid messages, 5 corrupted - for i := 0; i < 19; i++ { + for i := 0; i < 19; i++ { // 1 message leftover in 4th file equal(t, <-dq.ReadChan(), msg) } @@ -161,6 +163,15 @@ func TestDiskQueueCorruption(t *testing.T) { dqFn = dq.(*diskQueue).fileName(3) os.Truncate(dqFn, 100) + dq.Put(msg) // in 5th file + + equal(t, <-dq.ReadChan(), msg) + + // write a corrupt (len 0) message at the 5th (current) file + dq.(*diskQueue).writeFile.Write([]byte{0, 0, 0, 0}) + + // force a new 6th file - put into 5th, then readOne errors, then put into 6th + dq.Put(msg) dq.Put(msg) equal(t, <-dq.ReadChan(), msg) @@ -176,7 +187,7 @@ func TestDiskQueueTorture(t *testing.T) { panic(err) } defer os.RemoveAll(tmpDir) - dq := newDiskQueue(dqName, tmpDir, 262144, 2500, 2*time.Second, l) + dq := newDiskQueue(dqName, tmpDir, 262144, 0, 1<<10, 2500, 2*time.Second, l) nequal(t, dq, nil) equal(t, dq.Depth(), int64(0)) @@ -217,7 +228,7 @@ func TestDiskQueueTorture(t *testing.T) { t.Logf("restarting diskqueue") - dq = newDiskQueue(dqName, tmpDir, 262144, 2500, 2*time.Second, l) + dq = newDiskQueue(dqName, tmpDir, 262144, 0, 1<<10, 2500, 2*time.Second, l) nequal(t, dq, nil) equal(t, dq.Depth(), depth) @@ -265,7 +276,7 @@ func BenchmarkDiskQueuePut(b *testing.B) { panic(err) } defer os.RemoveAll(tmpDir) - dq := newDiskQueue(dqName, tmpDir, 1024768*100, 2500, 2*time.Second, l) + dq := newDiskQueue(dqName, tmpDir, 1024768*100, 0, 1<<10, 2500, 2*time.Second, l) size := 1024 b.SetBytes(int64(size)) data := make([]byte, size) @@ -333,7 +344,7 @@ func BenchmarkDiskQueueGet(b *testing.B) { panic(err) } defer os.RemoveAll(tmpDir) - dq := newDiskQueue(dqName, tmpDir, 1024768, 2500, 2*time.Second, l) + dq := newDiskQueue(dqName, tmpDir, 1024768, 0, 1<<10, 2500, 2*time.Second, l) for i := 0; i < b.N; i++ { dq.Put([]byte("aaaaaaaaaaaaaaaaaaaaaaaaaaa")) } diff --git a/nsqd/message.go b/nsqd/message.go index 67d8db86a..fea318b96 100644 --- a/nsqd/message.go +++ b/nsqd/message.go @@ -9,7 +9,10 @@ import ( "time" ) -const MsgIDLength = 16 +const ( + MsgIDLength = 16 + minValidMsgLength = MsgIDLength + 8 + 2 // Timestamp + Attempts +) type MessageID [MsgIDLength]byte @@ -66,7 +69,7 @@ func (m *Message) WriteTo(w io.Writer) (int64, error) { func decodeMessage(b []byte) (*Message, error) { var msg Message - if len(b) < 26 { + if len(b) < minValidMsgLength { return nil, fmt.Errorf("invalid message buffer size (%d)", len(b)) } diff --git a/nsqd/topic.go b/nsqd/topic.go index da614ac84..925c1529d 100644 --- a/nsqd/topic.go +++ b/nsqd/topic.go @@ -56,6 +56,8 @@ func NewTopic(topicName string, ctx *context, deleteCallback func(*Topic)) *Topi t.backend = newDiskQueue(topicName, ctx.nsqd.opts.DataPath, ctx.nsqd.opts.MaxBytesPerFile, + int32(minValidMsgLength), + int32(ctx.nsqd.opts.MaxMsgSize), ctx.nsqd.opts.SyncEvery, ctx.nsqd.opts.SyncTimeout, ctx.nsqd.opts.Logger)