From c5d2cc2d6b79c02379826f4478605691bce767cc Mon Sep 17 00:00:00 2001 From: Matt Reiferson Date: Wed, 10 Jul 2013 09:31:09 -0400 Subject: [PATCH 1/4] nsqd: add deflate compression negotiation: * add --deflate and --max-deflate-level flag * add feature negotiation for nsqd clients --- nsqd/client_v2.go | 34 ++++++++++++++++-- nsqd/main.go | 10 ++++++ nsqd/nsqd.go | 7 ++++ nsqd/protocol_v2.go | 48 +++++++++++++++++++------ nsqd/protocol_v2_test.go | 78 ++++++++++++++++++++++++++++++++++++++++ test.sh | 4 +-- 6 files changed, 166 insertions(+), 15 deletions(-) diff --git a/nsqd/client_v2.go b/nsqd/client_v2.go index 6452f6afd..92b00536a 100644 --- a/nsqd/client_v2.go +++ b/nsqd/client_v2.go @@ -2,6 +2,7 @@ package main import ( "bufio" + "compress/flate" "crypto/tls" "errors" "fmt" @@ -21,6 +22,8 @@ type IdentifyDataV2 struct { OutputBufferTimeout int `json:"output_buffer_timeout"` FeatureNegotiation bool `json:"feature_negotiation"` TLSv1 bool `json:"tls_v1"` + Deflate bool `json:"deflate"` + DeflateLevel int `json:"deflate_level"` } type ClientV2 struct { @@ -35,9 +38,10 @@ type ClientV2 struct { net.Conn sync.Mutex - ID int64 - context *Context - tlsConn net.Conn + ID int64 + context *Context + tlsConn net.Conn + flateWriter *flate.Writer // buffered IO Reader *bufio.Reader @@ -315,10 +319,34 @@ func (c *ClientV2) UpgradeTLS() error { return nil } +func (c *ClientV2) UpgradeDeflate(level int) error { + c.Lock() + defer c.Unlock() + + conn := c.Conn + if c.tlsConn != nil { + conn = c.tlsConn + } + + fr := flate.NewReader(conn) + c.Reader = bufio.NewReaderSize(fr, 16*1024) + + fw, _ := flate.NewWriter(conn, level) + c.flateWriter = fw + c.Writer = bufio.NewWriterSize(fw, c.OutputBufferSize) + + return nil +} + func (c *ClientV2) Flush() error { err := c.Writer.Flush() if err != nil { return err } + + if c.flateWriter != nil { + return c.flateWriter.Flush() + } + return nil } diff --git a/nsqd/main.go b/nsqd/main.go index 4354ff7cc..9d99a82ff 100644 --- a/nsqd/main.go +++ b/nsqd/main.go @@ -54,6 +54,10 @@ var ( // TLS config tlsCert = flag.String("tls-cert", "", "path to certificate file") tlsKey = flag.String("tls-key", "", "path to private key file") + + // compression + deflateEnabled = flag.Bool("deflate", true, "enable deflate feature negotiation (client compression)") + maxDeflateLevel = flag.Int("max-deflate-level", 6, "max deflate compression level a client can negotiate (> values == > nsqd CPU usage)") ) func init() { @@ -107,6 +111,10 @@ func main() { // flagToDuration will fatally error if it is invalid msgTimeoutDuration := flagToDuration(*msgTimeout, time.Millisecond, "--msg-timeout") + if *maxDeflateLevel < 1 || *maxDeflateLevel > 9 { + log.Fatalf("--max-deflate-level must be [1,9]") + } + options := NewNsqdOptions() options.maxRdyCount = *maxRdyCount options.maxMessageSize = *maxMessageSize @@ -124,6 +132,8 @@ func main() { options.maxOutputBufferTimeout = *maxOutputBufferTimeout options.tlsCert = *tlsCert options.tlsKey = *tlsKey + options.deflateEnabled = *deflateEnabled + options.maxDeflateLevel = *maxDeflateLevel if *statsdAddress != "" { // flagToDuration will fatally error if it is invalid diff --git a/nsqd/nsqd.go b/nsqd/nsqd.go index 765724f97..6ab6f68bd 100644 --- a/nsqd/nsqd.go +++ b/nsqd/nsqd.go @@ -77,6 +77,10 @@ type nsqdOptions struct { // TLS config tlsCert string tlsKey string + + // deflate + deflateEnabled bool + maxDeflateLevel int } func NewNsqdOptions() *nsqdOptions { @@ -106,6 +110,9 @@ func NewNsqdOptions() *nsqdOptions { tlsCert: "", tlsKey: "", + + deflateEnabled: true, + maxDeflateLevel: -1, } } diff --git a/nsqd/protocol_v2.go b/nsqd/protocol_v2.go index 20e7f2b63..b7a8cc9fb 100644 --- a/nsqd/protocol_v2.go +++ b/nsqd/protocol_v2.go @@ -9,6 +9,7 @@ import ( "github.com/bitly/nsq/util" "io" "log" + "math" "net" "sync/atomic" "time" @@ -313,19 +314,33 @@ func (p *ProtocolV2) IDENTIFY(client *ClientV2, params [][]byte) ([]byte, error) } tlsv1 := p.context.nsqd.tlsConfig != nil && identifyData.TLSv1 + deflate := p.context.nsqd.options.deflateEnabled && identifyData.Deflate + deflateLevel := 0 + if deflate { + if identifyData.DeflateLevel <= 0 { + deflateLevel = 6 + } + deflateLevel = int(math.Min(float64(deflateLevel), float64(p.context.nsqd.options.maxDeflateLevel))) + } resp, err := json.Marshal(struct { - MaxRdyCount int64 `json:"max_rdy_count"` - Version string `json:"version"` - MaxMsgTimeout int64 `json:"max_msg_timeout"` - MsgTimeout int64 `json:"msg_timeout"` - TLSv1 bool `json:"tls_v1"` + MaxRdyCount int64 `json:"max_rdy_count"` + Version string `json:"version"` + MaxMsgTimeout int64 `json:"max_msg_timeout"` + MsgTimeout int64 `json:"msg_timeout"` + TLSv1 bool `json:"tls_v1"` + Deflate bool `json:"deflate"` + DeflateLevel int `json:"deflate_level"` + MaxDeflateLevel int `json:"max_deflate_level"` }{ - MaxRdyCount: p.context.nsqd.options.maxRdyCount, - Version: util.BINARY_VERSION, - MaxMsgTimeout: int64(p.context.nsqd.options.maxMsgTimeout / time.Millisecond), - MsgTimeout: int64(p.context.nsqd.options.msgTimeout / time.Millisecond), - TLSv1: tlsv1, + MaxRdyCount: p.context.nsqd.options.maxRdyCount, + Version: util.BINARY_VERSION, + MaxMsgTimeout: int64(p.context.nsqd.options.maxMsgTimeout / time.Millisecond), + MsgTimeout: int64(p.context.nsqd.options.msgTimeout / time.Millisecond), + TLSv1: tlsv1, + Deflate: deflate, + DeflateLevel: deflateLevel, + MaxDeflateLevel: p.context.nsqd.options.maxDeflateLevel, }) if err != nil { panic("should never happen") @@ -349,6 +364,19 @@ func (p *ProtocolV2) IDENTIFY(client *ClientV2, params [][]byte) ([]byte, error) } } + if deflate { + log.Printf("PROTOCOL(V2): [%s] upgrading connection to deflate", client) + err = client.UpgradeDeflate(deflateLevel) + if err != nil { + return nil, util.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error()) + } + + err = p.Send(client, nsq.FrameTypeResponse, okBytes) + if err != nil { + return nil, util.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error()) + } + } + return nil, nil } diff --git a/nsqd/protocol_v2_test.go b/nsqd/protocol_v2_test.go index fb37c5c54..cf874c9e4 100644 --- a/nsqd/protocol_v2_test.go +++ b/nsqd/protocol_v2_test.go @@ -3,6 +3,7 @@ package main import ( "bufio" "bytes" + "compress/flate" "crypto/tls" "encoding/json" "fmt" @@ -764,6 +765,83 @@ func TestTLS(t *testing.T) { assert.Equal(t, data, []byte("OK")) } +func TestDeflate(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stdout) + + *verbose = true + options := NewNsqdOptions() + options.deflateEnabled = true + tcpAddr, _, nsqd := mustStartNSQd(options) + defer nsqd.Exit() + + conn, err := mustConnectNSQd(tcpAddr) + assert.Equal(t, err, nil) + + data := identifyFeatureNegotiation(t, conn, map[string]interface{}{"deflate": true}) + r := struct { + Deflate bool `json:"deflate"` + }{} + err = json.Unmarshal(data, &r) + assert.Equal(t, err, nil) + assert.Equal(t, r.Deflate, true) + + compressConn := flate.NewReader(conn) + resp, _ := nsq.ReadResponse(compressConn) + frameType, data, _ := nsq.UnpackResponse(resp) + log.Printf("frameType: %d, data: %s", frameType, data) + assert.Equal(t, frameType, nsq.FrameTypeResponse) + assert.Equal(t, data, []byte("OK")) +} + +func TestTLSDeflate(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stdout) + + *verbose = true + options := NewNsqdOptions() + options.deflateEnabled = true + options.tlsCert = "./test/cert.pem" + options.tlsKey = "./test/key.pem" + tcpAddr, _, nsqd := mustStartNSQd(options) + defer nsqd.Exit() + + conn, err := mustConnectNSQd(tcpAddr) + assert.Equal(t, err, nil) + + data := identifyFeatureNegotiation(t, conn, map[string]interface{}{"tls_v1": true, "deflate": true}) + r := struct { + TLSv1 bool `json:"tls_v1"` + Deflate bool `json:"deflate"` + }{} + err = json.Unmarshal(data, &r) + assert.Equal(t, err, nil) + assert.Equal(t, r.TLSv1, true) + assert.Equal(t, r.Deflate, true) + + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + tlsConn := tls.Client(conn, tlsConfig) + + err = tlsConn.Handshake() + assert.Equal(t, err, nil) + + resp, _ := nsq.ReadResponse(tlsConn) + frameType, data, _ := nsq.UnpackResponse(resp) + log.Printf("frameType: %d, data: %s", frameType, data) + assert.Equal(t, frameType, nsq.FrameTypeResponse) + assert.Equal(t, data, []byte("OK")) + + compressConn := flate.NewReader(tlsConn) + + resp, _ = nsq.ReadResponse(compressConn) + frameType, data, _ = nsq.UnpackResponse(resp) + log.Printf("frameType: %d, data: %s", frameType, data) + assert.Equal(t, frameType, nsq.FrameTypeResponse) + assert.Equal(t, data, []byte("OK")) +} + func BenchmarkProtocolV2Exec(b *testing.B) { b.StopTimer() log.SetOutput(ioutil.Discard) diff --git a/test.sh b/test.sh index 0e80ea308..8ba439bcd 100755 --- a/test.sh +++ b/test.sh @@ -21,8 +21,8 @@ popd >/dev/null pushd nsqd >/dev/null go build rm -f *.dat -echo "starting nsqd --data-path=/tmp --lookupd-tcp-address=127.0.0.1:4160 --tls-cert=./test/cert.pem --tls-key=./test/key.pem" -./nsqd --data-path=/tmp --lookupd-tcp-address=127.0.0.1:4160 --tls-cert=./test/cert.pem --tls-key=./test/key.pem >/dev/null 2>&1 & +echo "starting nsqd --data-path=/tmp --lookupd-tcp-address=127.0.0.1:4160 --tls-cert=./test/cert.pem --tls-key=./test/key.pem --deflate" +./nsqd --data-path=/tmp --lookupd-tcp-address=127.0.0.1:4160 --tls-cert=./test/cert.pem --tls-key=./test/key.pem --deflate >/dev/null 2>&1 & NSQD_PID=$! popd >/dev/null From d3e525e347ad52d1ff7cea6eadda37f63a5bef6d Mon Sep 17 00:00:00 2001 From: Matt Reiferson Date: Fri, 30 Aug 2013 20:53:14 -0400 Subject: [PATCH 2/4] nsqd: add framed snappy compression --- .travis.yml | 5 +- dist.sh | 3 +- nsqd/client_v2.go | 27 +++++++-- nsqd/main.go | 2 + nsqd/nsqd.go | 4 +- nsqd/protocol_v2.go | 20 +++++++ nsqd/protocol_v2_test.go | 120 ++++++++++++++++++++++++++++++++++++--- test.sh | 5 +- 8 files changed, 168 insertions(+), 18 deletions(-) diff --git a/.travis.yml b/.travis.yml index aca21f34f..8cbdff09a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,10 +6,11 @@ env: - GOARCH=amd64 - GOARCH=386 install: - - go get github.com/bmizerany/assert - go get github.com/bitly/go-nsq - - go get github.com/bitly/go-hostpool - go get github.com/bitly/go-simplejson + - go get github.com/mreiferson/go-snappystream + - go get github.com/bitly/go-hostpool + - go get github.com/bmizerany/assert script: - pushd $TRAVIS_BUILD_DIR - ./test.sh diff --git a/dist.sh b/dist.sh index d221a6a8f..4bbc1db7c 100755 --- a/dist.sh +++ b/dist.sh @@ -20,9 +20,10 @@ git archive HEAD | tar -x -C $TMPGOPATH/src/github.com/bitly/nsq export GOPATH="$TMPGOPATH:$GOROOT" echo "... getting dependencies" +go get -v github.com/bitly/go-nsq go get -v github.com/bitly/go-simplejson +go get -v github.com/mreiferson/go-snappystream go get -v github.com/bitly/go-hostpool -go get -v github.com/bitly/go-nsq go get -v github.com/bmizerany/assert pushd $TMPGOPATH/src/github.com/bitly/nsq diff --git a/nsqd/client_v2.go b/nsqd/client_v2.go index 92b00536a..55e01c695 100644 --- a/nsqd/client_v2.go +++ b/nsqd/client_v2.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "github.com/bitly/go-nsq" + "github.com/mreiferson/go-snappystream" "log" "net" "sync" @@ -24,6 +25,7 @@ type IdentifyDataV2 struct { TLSv1 bool `json:"tls_v1"` Deflate bool `json:"deflate"` DeflateLevel int `json:"deflate_level"` + Snappy bool `json:"snappy"` } type ClientV2 struct { @@ -38,12 +40,12 @@ type ClientV2 struct { net.Conn sync.Mutex - ID int64 - context *Context - tlsConn net.Conn - flateWriter *flate.Writer + ID int64 + context *Context + + tlsConn net.Conn + flateWriter *flate.Writer - // buffered IO Reader *bufio.Reader Writer *bufio.Writer OutputBufferSize int @@ -338,6 +340,21 @@ func (c *ClientV2) UpgradeDeflate(level int) error { return nil } +func (c *ClientV2) UpgradeSnappy() error { + c.Lock() + defer c.Unlock() + + conn := c.Conn + if c.tlsConn != nil { + conn = c.tlsConn + } + + c.Reader = bufio.NewReaderSize(snappystream.NewReader(conn, snappystream.SkipVerifyChecksum), DefaultBufferSize) + c.Writer = bufio.NewWriterSize(snappystream.NewWriter(conn), c.OutputBufferSize) + + return nil +} + func (c *ClientV2) Flush() error { err := c.Writer.Flush() if err != nil { diff --git a/nsqd/main.go b/nsqd/main.go index 9d99a82ff..dc81871f3 100644 --- a/nsqd/main.go +++ b/nsqd/main.go @@ -58,6 +58,7 @@ var ( // compression deflateEnabled = flag.Bool("deflate", true, "enable deflate feature negotiation (client compression)") maxDeflateLevel = flag.Int("max-deflate-level", 6, "max deflate compression level a client can negotiate (> values == > nsqd CPU usage)") + snappyEnabled = flag.Bool("snappy", true, "enable snappy feature negotiation (client compression)") ) func init() { @@ -134,6 +135,7 @@ func main() { options.tlsKey = *tlsKey options.deflateEnabled = *deflateEnabled options.maxDeflateLevel = *maxDeflateLevel + options.snappyEnabled = *snappyEnabled if *statsdAddress != "" { // flagToDuration will fatally error if it is invalid diff --git a/nsqd/nsqd.go b/nsqd/nsqd.go index 6ab6f68bd..862c38434 100644 --- a/nsqd/nsqd.go +++ b/nsqd/nsqd.go @@ -78,9 +78,10 @@ type nsqdOptions struct { tlsCert string tlsKey string - // deflate + // compression deflateEnabled bool maxDeflateLevel int + snappyEnabled bool } func NewNsqdOptions() *nsqdOptions { @@ -113,6 +114,7 @@ func NewNsqdOptions() *nsqdOptions { deflateEnabled: true, maxDeflateLevel: -1, + snappyEnabled: true, } } diff --git a/nsqd/protocol_v2.go b/nsqd/protocol_v2.go index b7a8cc9fb..b84125de5 100644 --- a/nsqd/protocol_v2.go +++ b/nsqd/protocol_v2.go @@ -322,6 +322,11 @@ func (p *ProtocolV2) IDENTIFY(client *ClientV2, params [][]byte) ([]byte, error) } deflateLevel = int(math.Min(float64(deflateLevel), float64(p.context.nsqd.options.maxDeflateLevel))) } + snappy := p.context.nsqd.options.snappyEnabled && identifyData.Snappy + + if deflate && snappy { + return nil, util.NewFatalClientErr(nil, "E_IDENTIFY_FAILED", "cannot enable both deflate and snappy compression") + } resp, err := json.Marshal(struct { MaxRdyCount int64 `json:"max_rdy_count"` @@ -332,6 +337,7 @@ func (p *ProtocolV2) IDENTIFY(client *ClientV2, params [][]byte) ([]byte, error) Deflate bool `json:"deflate"` DeflateLevel int `json:"deflate_level"` MaxDeflateLevel int `json:"max_deflate_level"` + Snappy bool `json:"snappy"` }{ MaxRdyCount: p.context.nsqd.options.maxRdyCount, Version: util.BINARY_VERSION, @@ -341,6 +347,7 @@ func (p *ProtocolV2) IDENTIFY(client *ClientV2, params [][]byte) ([]byte, error) Deflate: deflate, DeflateLevel: deflateLevel, MaxDeflateLevel: p.context.nsqd.options.maxDeflateLevel, + Snappy: snappy, }) if err != nil { panic("should never happen") @@ -364,6 +371,19 @@ func (p *ProtocolV2) IDENTIFY(client *ClientV2, params [][]byte) ([]byte, error) } } + if snappy { + log.Printf("PROTOCOL(V2): [%s] upgrading connection to snappy", client) + err = client.UpgradeSnappy() + if err != nil { + return nil, util.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error()) + } + + err = p.Send(client, nsq.FrameTypeResponse, okBytes) + if err != nil { + return nil, util.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error()) + } + } + if deflate { log.Printf("PROTOCOL(V2): [%s] upgrading connection to deflate", client) err = client.UpgradeDeflate(deflateLevel) diff --git a/nsqd/protocol_v2_test.go b/nsqd/protocol_v2_test.go index cf874c9e4..7adeceb06 100644 --- a/nsqd/protocol_v2_test.go +++ b/nsqd/protocol_v2_test.go @@ -9,6 +9,8 @@ import ( "fmt" "github.com/bitly/go-nsq" "github.com/bmizerany/assert" + "github.com/mreiferson/go-snappystream" + "io" "io/ioutil" "log" "math" @@ -40,7 +42,7 @@ func mustConnectNSQd(tcpAddr *net.TCPAddr) (net.Conn, error) { return conn, nil } -func identify(t *testing.T, conn net.Conn) { +func identify(t *testing.T, conn io.ReadWriter) { ci := make(map[string]interface{}) ci["short_id"] = "test" ci["long_id"] = "test" @@ -50,7 +52,7 @@ func identify(t *testing.T, conn net.Conn) { readValidate(t, conn, nsq.FrameTypeResponse, "OK") } -func identifyHeartbeatInterval(t *testing.T, conn net.Conn, interval int, f int32, d string) { +func identifyHeartbeatInterval(t *testing.T, conn io.ReadWriter, interval int, f int32, d string) { ci := make(map[string]interface{}) ci["short_id"] = "test" ci["long_id"] = "test" @@ -61,7 +63,7 @@ func identifyHeartbeatInterval(t *testing.T, conn net.Conn, interval int, f int3 readValidate(t, conn, f, d) } -func identifyFeatureNegotiation(t *testing.T, conn net.Conn, extra map[string]interface{}) []byte { +func identifyFeatureNegotiation(t *testing.T, conn io.ReadWriter, extra map[string]interface{}) []byte { ci := make(map[string]interface{}) ci["short_id"] = "test" ci["long_id"] = "test" @@ -82,7 +84,7 @@ func identifyFeatureNegotiation(t *testing.T, conn net.Conn, extra map[string]in return data } -func identifyOutputBuffering(t *testing.T, conn net.Conn, size int, timeout int, f int32, d string) { +func identifyOutputBuffering(t *testing.T, conn io.ReadWriter, size int, timeout int, f int32, d string) { ci := make(map[string]interface{}) ci["short_id"] = "test" ci["long_id"] = "test" @@ -94,13 +96,13 @@ func identifyOutputBuffering(t *testing.T, conn net.Conn, size int, timeout int, readValidate(t, conn, f, d) } -func sub(t *testing.T, conn net.Conn, topicName string, channelName string) { +func sub(t *testing.T, conn io.ReadWriter, topicName string, channelName string) { err := nsq.Subscribe(topicName, channelName).Write(conn) assert.Equal(t, err, nil) readValidate(t, conn, nsq.FrameTypeResponse, "OK") } -func subFail(t *testing.T, conn net.Conn, topicName string, channelName string) { +func subFail(t *testing.T, conn io.ReadWriter, topicName string, channelName string) { err := nsq.Subscribe(topicName, channelName).Write(conn) assert.Equal(t, err, nil) resp, err := nsq.ReadResponse(conn) @@ -108,7 +110,7 @@ func subFail(t *testing.T, conn net.Conn, topicName string, channelName string) assert.Equal(t, frameType, nsq.FrameTypeError) } -func readValidate(t *testing.T, conn net.Conn, f int32, d string) { +func readValidate(t *testing.T, conn io.Reader, f int32, d string) { resp, err := nsq.ReadResponse(conn) assert.Equal(t, err, nil) frameType, data, err := nsq.UnpackResponse(resp) @@ -794,6 +796,62 @@ func TestDeflate(t *testing.T) { assert.Equal(t, data, []byte("OK")) } +type readWriter struct { + io.Reader + io.Writer +} + +func TestSnappy(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stdout) + + *verbose = true + options := NewNsqdOptions() + options.snappyEnabled = true + tcpAddr, _, nsqd := mustStartNSQd(options) + defer nsqd.Exit() + + conn, err := mustConnectNSQd(tcpAddr) + assert.Equal(t, err, nil) + + data := identifyFeatureNegotiation(t, conn, map[string]interface{}{"snappy": true}) + r := struct { + Snappy bool `json:"snappy"` + }{} + err = json.Unmarshal(data, &r) + assert.Equal(t, err, nil) + assert.Equal(t, r.Snappy, true) + + compressConn := snappystream.NewReader(conn, snappystream.SkipVerifyChecksum) + resp, _ := nsq.ReadResponse(compressConn) + frameType, data, _ := nsq.UnpackResponse(resp) + log.Printf("frameType: %d, data: %s", frameType, data) + assert.Equal(t, frameType, nsq.FrameTypeResponse) + assert.Equal(t, data, []byte("OK")) + + msgBody := make([]byte, 128000) + w := snappystream.NewWriter(conn) + + rw := readWriter{compressConn, w} + + topicName := "test_snappy" + strconv.Itoa(int(time.Now().Unix())) + sub(t, rw, topicName, "ch") + + err = nsq.Ready(1).Write(rw) + assert.Equal(t, err, nil) + + topic := nsqd.GetTopic(topicName) + msg := nsq.NewMessage(<-nsqd.idChan, msgBody) + topic.PutMessage(msg) + + resp, _ = nsq.ReadResponse(compressConn) + frameType, data, _ = nsq.UnpackResponse(resp) + msgOut, _ := nsq.DecodeMessage(data) + assert.Equal(t, frameType, nsq.FrameTypeMessage) + assert.Equal(t, msgOut.Id, msg.Id) + assert.Equal(t, msgOut.Body, msg.Body) +} + func TestTLSDeflate(t *testing.T) { log.SetOutput(ioutil.Discard) defer log.SetOutput(os.Stdout) @@ -842,6 +900,54 @@ func TestTLSDeflate(t *testing.T) { assert.Equal(t, data, []byte("OK")) } +func TestTLSSnappy(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stdout) + + *verbose = true + options := NewNsqdOptions() + options.snappyEnabled = true + options.tlsCert = "./test/cert.pem" + options.tlsKey = "./test/key.pem" + tcpAddr, _, nsqd := mustStartNSQd(options) + defer nsqd.Exit() + + conn, err := mustConnectNSQd(tcpAddr) + assert.Equal(t, err, nil) + + data := identifyFeatureNegotiation(t, conn, map[string]interface{}{"tls_v1": true, "snappy": true}) + r := struct { + TLSv1 bool `json:"tls_v1"` + Snappy bool `json:"snappy"` + }{} + err = json.Unmarshal(data, &r) + assert.Equal(t, err, nil) + assert.Equal(t, r.TLSv1, true) + assert.Equal(t, r.Snappy, true) + + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + tlsConn := tls.Client(conn, tlsConfig) + + err = tlsConn.Handshake() + assert.Equal(t, err, nil) + + resp, _ := nsq.ReadResponse(tlsConn) + frameType, data, _ := nsq.UnpackResponse(resp) + log.Printf("frameType: %d, data: %s", frameType, data) + assert.Equal(t, frameType, nsq.FrameTypeResponse) + assert.Equal(t, data, []byte("OK")) + + compressConn := snappystream.NewReader(tlsConn, snappystream.SkipVerifyChecksum) + + resp, _ = nsq.ReadResponse(compressConn) + frameType, data, _ = nsq.UnpackResponse(resp) + log.Printf("frameType: %d, data: %s", frameType, data) + assert.Equal(t, frameType, nsq.FrameTypeResponse) + assert.Equal(t, data, []byte("OK")) +} + func BenchmarkProtocolV2Exec(b *testing.B) { b.StopTimer() log.SetOutput(ioutil.Discard) diff --git a/test.sh b/test.sh index 8ba439bcd..78d07fc11 100755 --- a/test.sh +++ b/test.sh @@ -21,8 +21,9 @@ popd >/dev/null pushd nsqd >/dev/null go build rm -f *.dat -echo "starting nsqd --data-path=/tmp --lookupd-tcp-address=127.0.0.1:4160 --tls-cert=./test/cert.pem --tls-key=./test/key.pem --deflate" -./nsqd --data-path=/tmp --lookupd-tcp-address=127.0.0.1:4160 --tls-cert=./test/cert.pem --tls-key=./test/key.pem --deflate >/dev/null 2>&1 & +cmd="./nsqd --data-path=/tmp --lookupd-tcp-address=127.0.0.1:4160 --tls-cert=./test/cert.pem --tls-key=./test/key.pem" +echo "starting $cmd" +$cmd >/dev/null 2>&1 & NSQD_PID=$! popd >/dev/null From 04f324d658a2149e4118ce8d3609f0c159a73306 Mon Sep 17 00:00:00 2001 From: Matt Reiferson Date: Fri, 30 Aug 2013 20:56:33 -0400 Subject: [PATCH 3/4] nsqd: add/use DefaultBufferSize --- nsqd/client_v2.go | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/nsqd/client_v2.go b/nsqd/client_v2.go index 55e01c695..3de7d7197 100644 --- a/nsqd/client_v2.go +++ b/nsqd/client_v2.go @@ -15,6 +15,8 @@ import ( "time" ) +const DefaultBufferSize = 16 * 1024 + type IdentifyDataV2 struct { ShortId string `json:"short_id"` LongId string `json:"long_id"` @@ -79,12 +81,13 @@ func NewClientV2(id int64, conn net.Conn, context *Context) *ClientV2 { c := &ClientV2{ ID: id, - Conn: conn, context: context, - Reader: bufio.NewReaderSize(conn, 16*1024), - Writer: bufio.NewWriterSize(conn, 16*1024), - OutputBufferSize: 16 * 1024, + Conn: conn, + + Reader: bufio.NewReaderSize(conn, DefaultBufferSize), + Writer: bufio.NewWriterSize(conn, DefaultBufferSize), + OutputBufferSize: DefaultBufferSize, OutputBufferTimeout: time.NewTicker(250 * time.Millisecond), OutputBufferTimeoutUpdateChan: make(chan time.Duration, 1), @@ -315,7 +318,7 @@ func (c *ClientV2) UpgradeTLS() error { } c.tlsConn = tlsConn - c.Reader = bufio.NewReaderSize(c.tlsConn, 16*1024) + c.Reader = bufio.NewReaderSize(c.tlsConn, DefaultBufferSize) c.Writer = bufio.NewWriterSize(c.tlsConn, c.OutputBufferSize) return nil @@ -330,8 +333,7 @@ func (c *ClientV2) UpgradeDeflate(level int) error { conn = c.tlsConn } - fr := flate.NewReader(conn) - c.Reader = bufio.NewReaderSize(fr, 16*1024) + c.Reader = bufio.NewReaderSize(flate.NewReader(conn), DefaultBufferSize) fw, _ := flate.NewWriter(conn, level) c.flateWriter = fw From a237e04386ab8f576cc7de265107e47613a3ad3e Mon Sep 17 00:00:00 2001 From: Matt Reiferson Date: Sun, 1 Sep 2013 11:27:23 -0400 Subject: [PATCH 4/4] nsqd: re-org, move flush into client * drop defer in fast path --- nsqd/client_v2.go | 21 +++++++++++++++------ nsqd/protocol_v2.go | 20 +++++++++----------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/nsqd/client_v2.go b/nsqd/client_v2.go index 3de7d7197..7b971e5ab 100644 --- a/nsqd/client_v2.go +++ b/nsqd/client_v2.go @@ -39,17 +39,23 @@ type ClientV2 struct { FinishCount uint64 RequeueCount uint64 - net.Conn sync.Mutex ID int64 context *Context - tlsConn net.Conn + // original connection + net.Conn + + // connections based on negotiated features + tlsConn *tls.Conn flateWriter *flate.Writer - Reader *bufio.Reader - Writer *bufio.Writer + // reading/writing interfaces + Reader *bufio.Reader + Writer *bufio.Writer + + // output buffering OutputBufferSize int OutputBufferTimeout *time.Ticker OutputBufferTimeoutUpdateChan chan time.Duration @@ -85,8 +91,9 @@ func NewClientV2(id int64, conn net.Conn, context *Context) *ClientV2 { Conn: conn, - Reader: bufio.NewReaderSize(conn, DefaultBufferSize), - Writer: bufio.NewWriterSize(conn, DefaultBufferSize), + Reader: bufio.NewReaderSize(conn, DefaultBufferSize), + Writer: bufio.NewWriterSize(conn, DefaultBufferSize), + OutputBufferSize: DefaultBufferSize, OutputBufferTimeout: time.NewTicker(250 * time.Millisecond), OutputBufferTimeoutUpdateChan: make(chan time.Duration, 1), @@ -358,6 +365,8 @@ func (c *ClientV2) UpgradeSnappy() error { } func (c *ClientV2) Flush() error { + c.SetWriteDeadline(time.Now().Add(time.Second)) + err := c.Writer.Flush() if err != nil { return err diff --git a/nsqd/protocol_v2.go b/nsqd/protocol_v2.go index b84125de5..0074294d8 100644 --- a/nsqd/protocol_v2.go +++ b/nsqd/protocol_v2.go @@ -121,11 +121,11 @@ func (p *ProtocolV2) SendMessage(client *ClientV2, msg *nsq.Message, buf *bytes. func (p *ProtocolV2) Send(client *ClientV2, frameType int32, data []byte) error { client.Lock() - defer client.Unlock() client.SetWriteDeadline(time.Now().Add(time.Second)) _, err := util.SendFramedResponse(client.Writer, frameType, data) if err != nil { + client.Unlock() return err } @@ -133,15 +133,9 @@ func (p *ProtocolV2) Send(client *ClientV2, frameType int32, data []byte) error err = client.Flush() } - return err -} - -func (p *ProtocolV2) Flush(client *ClientV2) error { - client.Lock() - defer client.Unlock() + client.Unlock() - client.SetWriteDeadline(time.Now().Add(time.Second)) - return client.Flush() + return err } func (p *ProtocolV2) Exec(client *ClientV2, params [][]byte) ([]byte, error) { @@ -197,7 +191,9 @@ func (p *ProtocolV2) messagePump(client *ClientV2) { clientMsgChan = nil flusherChan = nil // force flush - err = p.Flush(client) + client.Lock() + err = client.Flush() + client.Unlock() if err != nil { goto exit } @@ -219,7 +215,9 @@ func (p *ProtocolV2) messagePump(client *ClientV2) { // if this case wins, we're either starved // or we won the race between other channels... // in either case, force flush - err = p.Flush(client) + client.Lock() + err = client.Flush() + client.Unlock() if err != nil { goto exit }