diff --git a/.gitignore b/.gitignore index 6cb34d165..79ae74292 100644 --- a/.gitignore +++ b/.gitignore @@ -42,3 +42,6 @@ _testmain.go *.exe profile + +# vim stuff +*.sw[op] diff --git a/nsqd/client_v2.go b/nsqd/client_v2.go index 8d6652ece..8d549cb72 100644 --- a/nsqd/client_v2.go +++ b/nsqd/client_v2.go @@ -28,6 +28,7 @@ type IdentifyDataV2 struct { Deflate bool `json:"deflate"` DeflateLevel int `json:"deflate_level"` Snappy bool `json:"snappy"` + SampleRate int32 `json:"sample_rate"` } type ClientV2 struct { @@ -69,6 +70,9 @@ type ClientV2 struct { LongIdentifier string SubEventChan chan *Channel + SampleRate int32 + SampleRateUpdateChan chan int32 + // re-usable buffer for reading the 4-byte lengths off the wire lenBuf [4]byte lenSlice []byte @@ -108,6 +112,8 @@ func NewClientV2(id int64, conn net.Conn, context *Context) *ClientV2 { State: nsq.StateInit, SubEventChan: make(chan *Channel, 1), + SampleRateUpdateChan: make(chan int32, 1), + // heartbeats are client configurable but default to 30s Heartbeat: time.NewTicker(context.nsqd.options.clientTimeout / 2), HeartbeatInterval: context.nsqd.options.clientTimeout / 2, @@ -132,7 +138,11 @@ func (c *ClientV2) Identify(data IdentifyDataV2) error { if err != nil { return err } - return c.SetOutputBufferTimeout(data.OutputBufferTimeout) + err = c.SetOutputBufferTimeout(data.OutputBufferTimeout) + if err != nil { + return err + } + return c.SetSampleRate(data.SampleRate) } func (c *ClientV2) Stats() ClientStats { @@ -147,6 +157,7 @@ func (c *ClientV2) Stats() ClientStats { FinishCount: atomic.LoadUint64(&c.FinishCount), RequeueCount: atomic.LoadUint64(&c.RequeueCount), ConnectTime: c.ConnectTime.Unix(), + SampleRate: c.SampleRate, } } @@ -314,6 +325,22 @@ func (c *ClientV2) SetOutputBufferTimeout(desiredTimeout int) error { return nil } +func (c *ClientV2) SetSampleRate(sampleRate int32) error { + if sampleRate < 0 || sampleRate > 99 { + return errors.New(fmt.Sprintf("sample rate (%d) is invalid", sampleRate)) + } + + if sampleRate != 0 { + c.SampleRate = sampleRate + select { + case c.SampleRateUpdateChan <- sampleRate: + default: + } + } + + return nil +} + func (c *ClientV2) UpgradeTLS() error { c.Lock() defer c.Unlock() diff --git a/nsqd/main.go b/nsqd/main.go index eb1440612..a80cf94f3 100644 --- a/nsqd/main.go +++ b/nsqd/main.go @@ -8,6 +8,7 @@ import ( "hash/crc32" "io" "log" + "math/rand" "net" "os" "os/signal" @@ -163,6 +164,9 @@ func main() { nsqd.httpAddr = httpAddr nsqd.lookupdTCPAddrs = lookupdTCPAddrs + // Set the random seed + rand.Seed(time.Now().UTC().UnixNano()) + nsqd.LoadMetadata() err = nsqd.PersistMetadata() if err != nil { diff --git a/nsqd/protocol_v2.go b/nsqd/protocol_v2.go index aa5f1ea4a..fecd6a475 100644 --- a/nsqd/protocol_v2.go +++ b/nsqd/protocol_v2.go @@ -10,6 +10,7 @@ import ( "io" "log" "math" + "math/rand" "net" "sync/atomic" "time" @@ -173,6 +174,7 @@ func (p *ProtocolV2) messagePump(client *ClientV2) { // the pathological case of a channel on a low volume topic // with >1 clients having >1 RDY counts var flusherChan <-chan time.Time + var sampleRate int32 // v2 opportunistically buffers data to clients to reduce write system calls // we force flush in two cases: @@ -183,6 +185,7 @@ func (p *ProtocolV2) messagePump(client *ClientV2) { subEventChan := client.SubEventChan heartbeatUpdateChan := client.HeartbeatUpdateChan outputBufferTimeoutUpdateChan := client.OutputBufferTimeoutUpdateChan + sampleRateUpdateChan := client.SampleRateUpdateChan flushed := true for { @@ -245,10 +248,18 @@ func (p *ProtocolV2) messagePump(client *ClientV2) { if err != nil { goto exit } + case sampleRate = <-sampleRateUpdateChan: + sampleRateUpdateChan = nil case msg, ok := <-clientMsgChan: if !ok { goto exit } + + // if we are sampling, do so here + if sampleRate > 0 && rand.Int31n(100) > sampleRate { + continue + } + subChannel.StartInFlightTimeout(msg, client.ID) client.SendingMessage() err = p.SendMessage(client, msg, &buf) @@ -335,6 +346,7 @@ func (p *ProtocolV2) IDENTIFY(client *ClientV2, params [][]byte) ([]byte, error) DeflateLevel int `json:"deflate_level"` MaxDeflateLevel int `json:"max_deflate_level"` Snappy bool `json:"snappy"` + SampleRate int32 `json:"sample_rate"` }{ MaxRdyCount: p.context.nsqd.options.maxRdyCount, Version: util.BINARY_VERSION, @@ -345,6 +357,7 @@ func (p *ProtocolV2) IDENTIFY(client *ClientV2, params [][]byte) ([]byte, error) DeflateLevel: deflateLevel, MaxDeflateLevel: p.context.nsqd.options.maxDeflateLevel, Snappy: snappy, + SampleRate: client.SampleRate, }) if err != nil { panic("should never happen") diff --git a/nsqd/protocol_v2_test.go b/nsqd/protocol_v2_test.go index 7adeceb06..3de53a073 100644 --- a/nsqd/protocol_v2_test.go +++ b/nsqd/protocol_v2_test.go @@ -14,6 +14,7 @@ import ( "io/ioutil" "log" "math" + "math/rand" "net" "os" "runtime" @@ -900,6 +901,62 @@ func TestTLSDeflate(t *testing.T) { assert.Equal(t, data, []byte("OK")) } +func TestSampling(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stdout) + + rand.Seed(time.Now().UTC().UnixNano()) + + num := 3000 + + *verbose = true + options := NewNsqdOptions() + options.maxRdyCount = int64(num) + tcpAddr, _, nsqd := mustStartNSQd(options) + defer nsqd.Exit() + + conn, err := mustConnectNSQd(tcpAddr) + assert.Equal(t, err, nil) + + data := identifyFeatureNegotiation(t, conn, map[string]interface{}{"sample_rate": int32(42)}) + r := struct { + SampleRate int32 `json:"sample_rate"` + }{} + err = json.Unmarshal(data, &r) + assert.Equal(t, err, nil) + assert.Equal(t, r.SampleRate, int32(42)) + + topicName := "test_sampling" + strconv.Itoa(int(time.Now().Unix())) + topic := nsqd.GetTopic(topicName) + for i := 0; i < num; i++ { + msg := nsq.NewMessage(<-nsqd.idChan, []byte("test body")) + topic.PutMessage(msg) + } + channel := topic.GetChannel("ch") + + // let the topic drain into the channel + time.Sleep(50 * time.Millisecond) + + sub(t, conn, topicName, "ch") + err = nsq.Ready(num).Write(conn) + assert.Equal(t, err, nil) + + doneChan := make(chan int) + go func() { + for { + if channel.Depth() == 0 { + close(doneChan) + return + } + time.Sleep(5 * time.Millisecond) + } + }() + <-doneChan + // within 3% + assert.Equal(t, len(channel.inFlightMessages) <= int(float64(num)*0.45), true) + assert.Equal(t, len(channel.inFlightMessages) >= int(float64(num)*0.39), true) +} + func TestTLSSnappy(t *testing.T) { log.SetOutput(ioutil.Discard) defer log.SetOutput(os.Stdout) diff --git a/nsqd/stats.go b/nsqd/stats.go index 2cec2d4d7..4e23bf2dd 100644 --- a/nsqd/stats.go +++ b/nsqd/stats.go @@ -70,6 +70,7 @@ type ClientStats struct { FinishCount uint64 `json:"finish_count"` RequeueCount uint64 `json:"requeue_count"` ConnectTime int64 `json:"connect_ts"` + SampleRate int32 `json:"sample_rate"` } type Topics []*Topic