Skip to content

Commit

Permalink
Merge pull request #286 from elubow/channel_sampling
Browse files Browse the repository at this point in the history
nsqd: channel sampling
  • Loading branch information
mreiferson committed Dec 18, 2013
2 parents d2449ba + 8010652 commit deabf26
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,6 @@ _testmain.go
*.exe

profile

# vim stuff
*.sw[op]
29 changes: 28 additions & 1 deletion nsqd/client_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type IdentifyDataV2 struct {
Deflate bool `json:"deflate"`
DeflateLevel int `json:"deflate_level"`
Snappy bool `json:"snappy"`
SampleRate int32 `json:"sample_rate"`
}

type ClientV2 struct {
Expand Down Expand Up @@ -69,6 +70,9 @@ type ClientV2 struct {
LongIdentifier string
SubEventChan chan *Channel

SampleRate int32
SampleRateUpdateChan chan int32

// re-usable buffer for reading the 4-byte lengths off the wire
lenBuf [4]byte
lenSlice []byte
Expand Down Expand Up @@ -108,6 +112,8 @@ func NewClientV2(id int64, conn net.Conn, context *Context) *ClientV2 {
State: nsq.StateInit,
SubEventChan: make(chan *Channel, 1),

SampleRateUpdateChan: make(chan int32, 1),

// heartbeats are client configurable but default to 30s
Heartbeat: time.NewTicker(context.nsqd.options.clientTimeout / 2),
HeartbeatInterval: context.nsqd.options.clientTimeout / 2,
Expand All @@ -132,7 +138,11 @@ func (c *ClientV2) Identify(data IdentifyDataV2) error {
if err != nil {
return err
}
return c.SetOutputBufferTimeout(data.OutputBufferTimeout)
err = c.SetOutputBufferTimeout(data.OutputBufferTimeout)
if err != nil {
return err
}
return c.SetSampleRate(data.SampleRate)
}

func (c *ClientV2) Stats() ClientStats {
Expand All @@ -147,6 +157,7 @@ func (c *ClientV2) Stats() ClientStats {
FinishCount: atomic.LoadUint64(&c.FinishCount),
RequeueCount: atomic.LoadUint64(&c.RequeueCount),
ConnectTime: c.ConnectTime.Unix(),
SampleRate: c.SampleRate,
}
}

Expand Down Expand Up @@ -314,6 +325,22 @@ func (c *ClientV2) SetOutputBufferTimeout(desiredTimeout int) error {
return nil
}

// SetSampleRate validates and records the sample rate requested by the client.
//
// sampleRate must lie in the range [0, 99]; any other value returns an error
// and leaves the client unchanged. A value of 0 is accepted but is a no-op:
// neither c.SampleRate nor SampleRateUpdateChan is touched.
//
// For a non-zero rate the value is stored in c.SampleRate and offered to
// SampleRateUpdateChan with a non-blocking send, so this call never stalls;
// if the channel cannot accept the value right now the notification is
// dropped (c.SampleRate is still updated).
func (c *ClientV2) SetSampleRate(sampleRate int32) error {
	if sampleRate < 0 || sampleRate > 99 {
		return fmt.Errorf("sample rate (%d) is invalid", sampleRate)
	}

	// zero leaves the current settings alone
	if sampleRate == 0 {
		return nil
	}

	c.SampleRate = sampleRate

	// best-effort notification: never block the caller on a full channel
	select {
	case c.SampleRateUpdateChan <- sampleRate:
	default:
	}

	return nil
}

func (c *ClientV2) UpgradeTLS() error {
c.Lock()
defer c.Unlock()
Expand Down
4 changes: 4 additions & 0 deletions nsqd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"hash/crc32"
"io"
"log"
"math/rand"
"net"
"os"
"os/signal"
Expand Down Expand Up @@ -163,6 +164,9 @@ func main() {
nsqd.httpAddr = httpAddr
nsqd.lookupdTCPAddrs = lookupdTCPAddrs

// Set the random seed
rand.Seed(time.Now().UTC().UnixNano())

nsqd.LoadMetadata()
err = nsqd.PersistMetadata()
if err != nil {
Expand Down
13 changes: 13 additions & 0 deletions nsqd/protocol_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io"
"log"
"math"
"math/rand"
"net"
"sync/atomic"
"time"
Expand Down Expand Up @@ -173,6 +174,7 @@ func (p *ProtocolV2) messagePump(client *ClientV2) {
// the pathological case of a channel on a low volume topic
// with >1 clients having >1 RDY counts
var flusherChan <-chan time.Time
var sampleRate int32

// v2 opportunistically buffers data to clients to reduce write system calls
// we force flush in two cases:
Expand All @@ -183,6 +185,7 @@ func (p *ProtocolV2) messagePump(client *ClientV2) {
subEventChan := client.SubEventChan
heartbeatUpdateChan := client.HeartbeatUpdateChan
outputBufferTimeoutUpdateChan := client.OutputBufferTimeoutUpdateChan
sampleRateUpdateChan := client.SampleRateUpdateChan
flushed := true

for {
Expand Down Expand Up @@ -245,10 +248,18 @@ func (p *ProtocolV2) messagePump(client *ClientV2) {
if err != nil {
goto exit
}
case sampleRate = <-sampleRateUpdateChan:
sampleRateUpdateChan = nil
case msg, ok := <-clientMsgChan:
if !ok {
goto exit
}

// if we are sampling, do so here
if sampleRate > 0 && rand.Int31n(100) > sampleRate {
continue
}

subChannel.StartInFlightTimeout(msg, client.ID)
client.SendingMessage()
err = p.SendMessage(client, msg, &buf)
Expand Down Expand Up @@ -335,6 +346,7 @@ func (p *ProtocolV2) IDENTIFY(client *ClientV2, params [][]byte) ([]byte, error)
DeflateLevel int `json:"deflate_level"`
MaxDeflateLevel int `json:"max_deflate_level"`
Snappy bool `json:"snappy"`
SampleRate int32 `json:"sample_rate"`
}{
MaxRdyCount: p.context.nsqd.options.maxRdyCount,
Version: util.BINARY_VERSION,
Expand All @@ -345,6 +357,7 @@ func (p *ProtocolV2) IDENTIFY(client *ClientV2, params [][]byte) ([]byte, error)
DeflateLevel: deflateLevel,
MaxDeflateLevel: p.context.nsqd.options.maxDeflateLevel,
Snappy: snappy,
SampleRate: client.SampleRate,
})
if err != nil {
panic("should never happen")
Expand Down
57 changes: 57 additions & 0 deletions nsqd/protocol_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"io/ioutil"
"log"
"math"
"math/rand"
"net"
"os"
"runtime"
Expand Down Expand Up @@ -900,6 +901,62 @@ func TestTLSDeflate(t *testing.T) {
assert.Equal(t, data, []byte("OK"))
}

// TestSampling negotiates sample_rate=42 during IDENTIFY, publishes a batch
// of messages, subscribes with a RDY count large enough to receive them all,
// and then checks that the number of in-flight messages on the channel is
// between 39% and 45% of the batch.
func TestSampling(t *testing.T) {
	log.SetOutput(ioutil.Discard)
	defer log.SetOutput(os.Stdout)

	rand.Seed(time.Now().UTC().UnixNano())

	msgCount := 3000

	*verbose = true
	opts := NewNsqdOptions()
	opts.maxRdyCount = int64(msgCount)
	tcpAddr, _, nsqd := mustStartNSQd(opts)
	defer nsqd.Exit()

	conn, err := mustConnectNSQd(tcpAddr)
	assert.Equal(t, err, nil)

	// negotiate the sample rate and confirm the server echoes it back
	identifyParams := map[string]interface{}{"sample_rate": int32(42)}
	respBody := identifyFeatureNegotiation(t, conn, identifyParams)
	var resp struct {
		SampleRate int32 `json:"sample_rate"`
	}
	err = json.Unmarshal(respBody, &resp)
	assert.Equal(t, err, nil)
	assert.Equal(t, resp.SampleRate, int32(42))

	topicName := "test_sampling" + strconv.Itoa(int(time.Now().Unix()))
	topic := nsqd.GetTopic(topicName)
	for sent := 0; sent < msgCount; sent++ {
		topic.PutMessage(nsq.NewMessage(<-nsqd.idChan, []byte("test body")))
	}
	channel := topic.GetChannel("ch")

	// let the topic drain into the channel
	time.Sleep(50 * time.Millisecond)

	sub(t, conn, topicName, "ch")
	err = nsq.Ready(msgCount).Write(conn)
	assert.Equal(t, err, nil)

	// block until the channel reports zero depth
	for channel.Depth() != 0 {
		time.Sleep(5 * time.Millisecond)
	}

	// accept anything from 39% to 45% of the published messages
	inFlight := len(channel.inFlightMessages)
	upperBound := int(float64(msgCount) * 0.45)
	lowerBound := int(float64(msgCount) * 0.39)
	assert.Equal(t, inFlight <= upperBound, true)
	assert.Equal(t, inFlight >= lowerBound, true)
}

func TestTLSSnappy(t *testing.T) {
log.SetOutput(ioutil.Discard)
defer log.SetOutput(os.Stdout)
Expand Down
1 change: 1 addition & 0 deletions nsqd/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ type ClientStats struct {
FinishCount uint64 `json:"finish_count"`
RequeueCount uint64 `json:"requeue_count"`
ConnectTime int64 `json:"connect_ts"`
SampleRate int32 `json:"sample_rate"`
}

type Topics []*Topic
Expand Down

0 comments on commit deabf26

Please sign in to comment.