diff --git a/.travis.yml b/.travis.yml index aca21f34f..8cbdff09a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -6,10 +6,11 @@ env: - GOARCH=amd64 - GOARCH=386 install: - - go get github.com/bmizerany/assert - go get github.com/bitly/go-nsq - - go get github.com/bitly/go-hostpool - go get github.com/bitly/go-simplejson + - go get github.com/mreiferson/go-snappystream + - go get github.com/bitly/go-hostpool + - go get github.com/bmizerany/assert script: - pushd $TRAVIS_BUILD_DIR - ./test.sh diff --git a/dist.sh b/dist.sh index d221a6a8f..4bbc1db7c 100755 --- a/dist.sh +++ b/dist.sh @@ -20,9 +20,10 @@ git archive HEAD | tar -x -C $TMPGOPATH/src/github.com/bitly/nsq export GOPATH="$TMPGOPATH:$GOROOT" echo "... getting dependencies" +go get -v github.com/bitly/go-nsq go get -v github.com/bitly/go-simplejson +go get -v github.com/mreiferson/go-snappystream go get -v github.com/bitly/go-hostpool -go get -v github.com/bitly/go-nsq go get -v github.com/bmizerany/assert pushd $TMPGOPATH/src/github.com/bitly/nsq diff --git a/nsqd/client_v2.go b/nsqd/client_v2.go index 6452f6afd..7b971e5ab 100644 --- a/nsqd/client_v2.go +++ b/nsqd/client_v2.go @@ -2,10 +2,12 @@ package main import ( "bufio" + "compress/flate" "crypto/tls" "errors" "fmt" "github.com/bitly/go-nsq" + "github.com/mreiferson/go-snappystream" "log" "net" "sync" @@ -13,6 +15,8 @@ import ( "time" ) +const DefaultBufferSize = 16 * 1024 + type IdentifyDataV2 struct { ShortId string `json:"short_id"` LongId string `json:"long_id"` @@ -21,6 +25,9 @@ type IdentifyDataV2 struct { OutputBufferTimeout int `json:"output_buffer_timeout"` FeatureNegotiation bool `json:"feature_negotiation"` TLSv1 bool `json:"tls_v1"` + Deflate bool `json:"deflate"` + DeflateLevel int `json:"deflate_level"` + Snappy bool `json:"snappy"` } type ClientV2 struct { @@ -32,16 +39,23 @@ type ClientV2 struct { FinishCount uint64 RequeueCount uint64 - net.Conn sync.Mutex ID int64 context *Context - tlsConn net.Conn - // buffered IO - Reader *bufio.Reader - Writer *bufio.Writer + // original connection + net.Conn + + // connections based on negotiated features + tlsConn *tls.Conn + flateWriter *flate.Writer + + // reading/writing interfaces + Reader *bufio.Reader + Writer *bufio.Writer + + // output buffering OutputBufferSize int OutputBufferTimeout *time.Ticker OutputBufferTimeoutUpdateChan chan time.Duration @@ -73,12 +87,14 @@ func NewClientV2(id int64, conn net.Conn, context *Context) *ClientV2 { c := &ClientV2{ ID: id, - Conn: conn, context: context, - Reader: bufio.NewReaderSize(conn, 16*1024), - Writer: bufio.NewWriterSize(conn, 16*1024), - OutputBufferSize: 16 * 1024, + Conn: conn, + + Reader: bufio.NewReaderSize(conn, DefaultBufferSize), + Writer: bufio.NewWriterSize(conn, DefaultBufferSize), + + OutputBufferSize: DefaultBufferSize, OutputBufferTimeout: time.NewTicker(250 * time.Millisecond), OutputBufferTimeoutUpdateChan: make(chan time.Duration, 1), @@ -309,16 +325,56 @@ func (c *ClientV2) UpgradeTLS() error { } c.tlsConn = tlsConn - c.Reader = bufio.NewReaderSize(c.tlsConn, 16*1024) + c.Reader = bufio.NewReaderSize(c.tlsConn, DefaultBufferSize) c.Writer = bufio.NewWriterSize(c.tlsConn, c.OutputBufferSize) return nil } +func (c *ClientV2) UpgradeDeflate(level int) error { + c.Lock() + defer c.Unlock() + + conn := c.Conn + if c.tlsConn != nil { + conn = c.tlsConn + } + + c.Reader = bufio.NewReaderSize(flate.NewReader(conn), DefaultBufferSize) + + fw, _ := flate.NewWriter(conn, level) + c.flateWriter = fw + c.Writer = bufio.NewWriterSize(fw, c.OutputBufferSize) + + return nil +} + +func (c *ClientV2) UpgradeSnappy() error { + c.Lock() + defer c.Unlock() + + conn := c.Conn + if c.tlsConn != nil { + conn = c.tlsConn + } + + c.Reader = bufio.NewReaderSize(snappystream.NewReader(conn, snappystream.SkipVerifyChecksum), DefaultBufferSize) + c.Writer = bufio.NewWriterSize(snappystream.NewWriter(conn), c.OutputBufferSize) + + return nil +} + func (c *ClientV2) Flush() error { + c.SetWriteDeadline(time.Now().Add(time.Second)) + err := c.Writer.Flush() if err != nil { return err } + + if c.flateWriter != nil { + return c.flateWriter.Flush() + } + return nil } diff --git a/nsqd/main.go b/nsqd/main.go index 4354ff7cc..dc81871f3 100644 --- a/nsqd/main.go +++ b/nsqd/main.go @@ -54,6 +54,11 @@ var ( // TLS config tlsCert = flag.String("tls-cert", "", "path to certificate file") tlsKey = flag.String("tls-key", "", "path to private key file") + + // compression + deflateEnabled = flag.Bool("deflate", true, "enable deflate feature negotiation (client compression)") + maxDeflateLevel = flag.Int("max-deflate-level", 6, "max deflate compression level a client can negotiate (> values == > nsqd CPU usage)") + snappyEnabled = flag.Bool("snappy", true, "enable snappy feature negotiation (client compression)") ) func init() { @@ -107,6 +112,10 @@ func main() { // flagToDuration will fatally error if it is invalid msgTimeoutDuration := flagToDuration(*msgTimeout, time.Millisecond, "--msg-timeout") + if *maxDeflateLevel < 1 || *maxDeflateLevel > 9 { + log.Fatalf("--max-deflate-level must be [1,9]") + } + options := NewNsqdOptions() options.maxRdyCount = *maxRdyCount options.maxMessageSize = *maxMessageSize @@ -124,6 +133,9 @@ func main() { options.maxOutputBufferTimeout = *maxOutputBufferTimeout options.tlsCert = *tlsCert options.tlsKey = *tlsKey + options.deflateEnabled = *deflateEnabled + options.maxDeflateLevel = *maxDeflateLevel + options.snappyEnabled = *snappyEnabled if *statsdAddress != "" { // flagToDuration will fatally error if it is invalid diff --git a/nsqd/nsqd.go b/nsqd/nsqd.go index 765724f97..862c38434 100644 --- a/nsqd/nsqd.go +++ b/nsqd/nsqd.go @@ -77,6 +77,11 @@ type nsqdOptions struct { // TLS config tlsCert string tlsKey string + + // compression + deflateEnabled bool + maxDeflateLevel int + snappyEnabled bool } func NewNsqdOptions() *nsqdOptions { @@ -106,6 +111,10 @@ func NewNsqdOptions() *nsqdOptions { tlsCert: "", tlsKey: "", + + deflateEnabled: true, + maxDeflateLevel: -1, + snappyEnabled: true, } } diff --git a/nsqd/protocol_v2.go b/nsqd/protocol_v2.go index 20e7f2b63..0074294d8 100644 --- a/nsqd/protocol_v2.go +++ b/nsqd/protocol_v2.go @@ -9,6 +9,7 @@ import ( "github.com/bitly/nsq/util" "io" "log" + "math" "net" "sync/atomic" "time" @@ -120,11 +121,11 @@ func (p *ProtocolV2) SendMessage(client *ClientV2, msg *nsq.Message, buf *bytes. func (p *ProtocolV2) Send(client *ClientV2, frameType int32, data []byte) error { client.Lock() - defer client.Unlock() client.SetWriteDeadline(time.Now().Add(time.Second)) _, err := util.SendFramedResponse(client.Writer, frameType, data) if err != nil { + client.Unlock() return err } @@ -132,15 +133,9 @@ func (p *ProtocolV2) Send(client *ClientV2, frameType int32, data []byte) error err = client.Flush() } - return err -} - -func (p *ProtocolV2) Flush(client *ClientV2) error { - client.Lock() - defer client.Unlock() + client.Unlock() - client.SetWriteDeadline(time.Now().Add(time.Second)) - return client.Flush() + return err } func (p *ProtocolV2) Exec(client *ClientV2, params [][]byte) ([]byte, error) { @@ -196,7 +191,9 @@ func (p *ProtocolV2) messagePump(client *ClientV2) { clientMsgChan = nil flusherChan = nil // force flush - err = p.Flush(client) + client.Lock() + err = client.Flush() + client.Unlock() if err != nil { goto exit } @@ -218,7 +215,9 @@ func (p *ProtocolV2) messagePump(client *ClientV2) { // if this case wins, we're either starved // or we won the race between other channels... // in either case, force flush - err = p.Flush(client) + client.Lock() + err = client.Flush() + client.Unlock() if err != nil { goto exit } @@ -313,19 +312,40 @@ func (p *ProtocolV2) IDENTIFY(client *ClientV2, params [][]byte) ([]byte, error) } tlsv1 := p.context.nsqd.tlsConfig != nil && identifyData.TLSv1 + deflate := p.context.nsqd.options.deflateEnabled && identifyData.Deflate + deflateLevel := 0 + if deflate { + if identifyData.DeflateLevel <= 0 { + deflateLevel = 6 + } + deflateLevel = int(math.Min(float64(deflateLevel), float64(p.context.nsqd.options.maxDeflateLevel))) + } + snappy := p.context.nsqd.options.snappyEnabled && identifyData.Snappy + + if deflate && snappy { + return nil, util.NewFatalClientErr(nil, "E_IDENTIFY_FAILED", "cannot enable both deflate and snappy compression") + } resp, err := json.Marshal(struct { - MaxRdyCount int64 `json:"max_rdy_count"` - Version string `json:"version"` - MaxMsgTimeout int64 `json:"max_msg_timeout"` - MsgTimeout int64 `json:"msg_timeout"` - TLSv1 bool `json:"tls_v1"` + MaxRdyCount int64 `json:"max_rdy_count"` + Version string `json:"version"` + MaxMsgTimeout int64 `json:"max_msg_timeout"` + MsgTimeout int64 `json:"msg_timeout"` + TLSv1 bool `json:"tls_v1"` + Deflate bool `json:"deflate"` + DeflateLevel int `json:"deflate_level"` + MaxDeflateLevel int `json:"max_deflate_level"` + Snappy bool `json:"snappy"` }{ - MaxRdyCount: p.context.nsqd.options.maxRdyCount, - Version: util.BINARY_VERSION, - MaxMsgTimeout: int64(p.context.nsqd.options.maxMsgTimeout / time.Millisecond), - MsgTimeout: int64(p.context.nsqd.options.msgTimeout / time.Millisecond), - TLSv1: tlsv1, + MaxRdyCount: p.context.nsqd.options.maxRdyCount, + Version: util.BINARY_VERSION, + MaxMsgTimeout: int64(p.context.nsqd.options.maxMsgTimeout / time.Millisecond), + MsgTimeout: int64(p.context.nsqd.options.msgTimeout / time.Millisecond), + TLSv1: tlsv1, + Deflate: deflate, + DeflateLevel: deflateLevel, + MaxDeflateLevel: p.context.nsqd.options.maxDeflateLevel, + Snappy: snappy, }) if err != nil { panic("should never happen") @@ -349,6 +369,32 @@ func (p *ProtocolV2) IDENTIFY(client *ClientV2, params [][]byte) ([]byte, error) } } + if snappy { + log.Printf("PROTOCOL(V2): [%s] upgrading connection to snappy", client) + err = client.UpgradeSnappy() + if err != nil { + return nil, util.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error()) + } + + err = p.Send(client, nsq.FrameTypeResponse, okBytes) + if err != nil { + return nil, util.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error()) + } + } + + if deflate { + log.Printf("PROTOCOL(V2): [%s] upgrading connection to deflate", client) + err = client.UpgradeDeflate(deflateLevel) + if err != nil { + return nil, util.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error()) + } + + err = p.Send(client, nsq.FrameTypeResponse, okBytes) + if err != nil { + return nil, util.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error()) + } + } + return nil, nil } diff --git a/nsqd/protocol_v2_test.go b/nsqd/protocol_v2_test.go index fb37c5c54..7adeceb06 100644 --- a/nsqd/protocol_v2_test.go +++ b/nsqd/protocol_v2_test.go @@ -3,11 +3,14 @@ package main import ( "bufio" "bytes" + "compress/flate" "crypto/tls" "encoding/json" "fmt" "github.com/bitly/go-nsq" "github.com/bmizerany/assert" + "github.com/mreiferson/go-snappystream" + "io" "io/ioutil" "log" "math" @@ -39,7 +42,7 @@ func mustConnectNSQd(tcpAddr *net.TCPAddr) (net.Conn, error) { return conn, nil } -func identify(t *testing.T, conn net.Conn) { +func identify(t *testing.T, conn io.ReadWriter) { ci := make(map[string]interface{}) ci["short_id"] = "test" ci["long_id"] = "test" @@ -49,7 +52,7 @@ func identify(t *testing.T, conn net.Conn) { readValidate(t, conn, nsq.FrameTypeResponse, "OK") } -func identifyHeartbeatInterval(t *testing.T, conn net.Conn, interval int, f int32, d string) { +func identifyHeartbeatInterval(t *testing.T, conn io.ReadWriter, interval int, f int32, d string) { ci := make(map[string]interface{}) ci["short_id"] = "test" ci["long_id"] = "test" @@ -60,7 +63,7 @@ func identifyHeartbeatInterval(t *testing.T, conn net.Conn, interval int, f int3 readValidate(t, conn, f, d) } -func identifyFeatureNegotiation(t *testing.T, conn net.Conn, extra map[string]interface{}) []byte { +func identifyFeatureNegotiation(t *testing.T, conn io.ReadWriter, extra map[string]interface{}) []byte { ci := make(map[string]interface{}) ci["short_id"] = "test" ci["long_id"] = "test" @@ -81,7 +84,7 @@ func identifyFeatureNegotiation(t *testing.T, conn net.Conn, extra map[string]in return data } -func identifyOutputBuffering(t *testing.T, conn net.Conn, size int, timeout int, f int32, d string) { +func identifyOutputBuffering(t *testing.T, conn io.ReadWriter, size int, timeout int, f int32, d string) { ci := make(map[string]interface{}) ci["short_id"] = "test" ci["long_id"] = "test" @@ -93,13 +96,13 @@ func identifyOutputBuffering(t *testing.T, conn net.Conn, size int, timeout int, readValidate(t, conn, f, d) } -func sub(t *testing.T, conn net.Conn, topicName string, channelName string) { +func sub(t *testing.T, conn io.ReadWriter, topicName string, channelName string) { err := nsq.Subscribe(topicName, channelName).Write(conn) assert.Equal(t, err, nil) readValidate(t, conn, nsq.FrameTypeResponse, "OK") } -func subFail(t *testing.T, conn net.Conn, topicName string, channelName string) { +func subFail(t *testing.T, conn io.ReadWriter, topicName string, channelName string) { err := nsq.Subscribe(topicName, channelName).Write(conn) assert.Equal(t, err, nil) resp, err := nsq.ReadResponse(conn) @@ -107,7 +110,7 @@ func subFail(t *testing.T, conn net.Conn, topicName string, channelName string) assert.Equal(t, frameType, nsq.FrameTypeError) } -func readValidate(t *testing.T, conn net.Conn, f int32, d string) { +func readValidate(t *testing.T, conn io.Reader, f int32, d string) { resp, err := nsq.ReadResponse(conn) assert.Equal(t, err, nil) frameType, data, err := nsq.UnpackResponse(resp) @@ -764,6 +767,187 @@ func TestTLS(t *testing.T) { assert.Equal(t, data, []byte("OK")) } +func TestDeflate(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stdout) + + *verbose = true + options := NewNsqdOptions() + options.deflateEnabled = true + tcpAddr, _, nsqd := mustStartNSQd(options) + defer nsqd.Exit() + + conn, err := mustConnectNSQd(tcpAddr) + assert.Equal(t, err, nil) + + data := identifyFeatureNegotiation(t, conn, map[string]interface{}{"deflate": true}) + r := struct { + Deflate bool `json:"deflate"` + }{} + err = json.Unmarshal(data, &r) + assert.Equal(t, err, nil) + assert.Equal(t, r.Deflate, true) + + compressConn := flate.NewReader(conn) + resp, _ := nsq.ReadResponse(compressConn) + frameType, data, _ := nsq.UnpackResponse(resp) + log.Printf("frameType: %d, data: %s", frameType, data) + assert.Equal(t, frameType, nsq.FrameTypeResponse) + assert.Equal(t, data, []byte("OK")) +} + +type readWriter struct { + io.Reader + io.Writer +} + +func TestSnappy(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stdout) + + *verbose = true + options := NewNsqdOptions() + options.snappyEnabled = true + tcpAddr, _, nsqd := mustStartNSQd(options) + defer nsqd.Exit() + + conn, err := mustConnectNSQd(tcpAddr) + assert.Equal(t, err, nil) + + data := identifyFeatureNegotiation(t, conn, map[string]interface{}{"snappy": true}) + r := struct { + Snappy bool `json:"snappy"` + }{} + err = json.Unmarshal(data, &r) + assert.Equal(t, err, nil) + assert.Equal(t, r.Snappy, true) + + compressConn := snappystream.NewReader(conn, snappystream.SkipVerifyChecksum) + resp, _ := nsq.ReadResponse(compressConn) + frameType, data, _ := nsq.UnpackResponse(resp) + log.Printf("frameType: %d, data: %s", frameType, data) + assert.Equal(t, frameType, nsq.FrameTypeResponse) + assert.Equal(t, data, []byte("OK")) + + msgBody := make([]byte, 128000) + w := snappystream.NewWriter(conn) + + rw := readWriter{compressConn, w} + + topicName := "test_snappy" + strconv.Itoa(int(time.Now().Unix())) + sub(t, rw, topicName, "ch") + + err = nsq.Ready(1).Write(rw) + assert.Equal(t, err, nil) + + topic := nsqd.GetTopic(topicName) + msg := nsq.NewMessage(<-nsqd.idChan, msgBody) + topic.PutMessage(msg) + + resp, _ = nsq.ReadResponse(compressConn) + frameType, data, _ = nsq.UnpackResponse(resp) + msgOut, _ := nsq.DecodeMessage(data) + assert.Equal(t, frameType, nsq.FrameTypeMessage) + assert.Equal(t, msgOut.Id, msg.Id) + assert.Equal(t, msgOut.Body, msg.Body) +} + +func TestTLSDeflate(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stdout) + + *verbose = true + options := NewNsqdOptions() + options.deflateEnabled = true + options.tlsCert = "./test/cert.pem" + options.tlsKey = "./test/key.pem" + tcpAddr, _, nsqd := mustStartNSQd(options) + defer nsqd.Exit() + + conn, err := mustConnectNSQd(tcpAddr) + assert.Equal(t, err, nil) + + data := identifyFeatureNegotiation(t, conn, map[string]interface{}{"tls_v1": true, "deflate": true}) + r := struct { + TLSv1 bool `json:"tls_v1"` + Deflate bool `json:"deflate"` + }{} + err = json.Unmarshal(data, &r) + assert.Equal(t, err, nil) + assert.Equal(t, r.TLSv1, true) + assert.Equal(t, r.Deflate, true) + + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + tlsConn := tls.Client(conn, tlsConfig) + + err = tlsConn.Handshake() + assert.Equal(t, err, nil) + + resp, _ := nsq.ReadResponse(tlsConn) + frameType, data, _ := nsq.UnpackResponse(resp) + log.Printf("frameType: %d, data: %s", frameType, data) + assert.Equal(t, frameType, nsq.FrameTypeResponse) + assert.Equal(t, data, []byte("OK")) + + compressConn := flate.NewReader(tlsConn) + + resp, _ = nsq.ReadResponse(compressConn) + frameType, data, _ = nsq.UnpackResponse(resp) + log.Printf("frameType: %d, data: %s", frameType, data) + assert.Equal(t, frameType, nsq.FrameTypeResponse) + assert.Equal(t, data, []byte("OK")) +} + +func TestTLSSnappy(t *testing.T) { + log.SetOutput(ioutil.Discard) + defer log.SetOutput(os.Stdout) + + *verbose = true + options := NewNsqdOptions() + options.snappyEnabled = true + options.tlsCert = "./test/cert.pem" + options.tlsKey = "./test/key.pem" + tcpAddr, _, nsqd := mustStartNSQd(options) + defer nsqd.Exit() + + conn, err := mustConnectNSQd(tcpAddr) + assert.Equal(t, err, nil) + + data := identifyFeatureNegotiation(t, conn, map[string]interface{}{"tls_v1": true, "snappy": true}) + r := struct { + TLSv1 bool `json:"tls_v1"` + Snappy bool `json:"snappy"` + }{} + err = json.Unmarshal(data, &r) + assert.Equal(t, err, nil) + assert.Equal(t, r.TLSv1, true) + assert.Equal(t, r.Snappy, true) + + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + } + tlsConn := tls.Client(conn, tlsConfig) + + err = tlsConn.Handshake() + assert.Equal(t, err, nil) + + resp, _ := nsq.ReadResponse(tlsConn) + frameType, data, _ := nsq.UnpackResponse(resp) + log.Printf("frameType: %d, data: %s", frameType, data) + assert.Equal(t, frameType, nsq.FrameTypeResponse) + assert.Equal(t, data, []byte("OK")) + + compressConn := snappystream.NewReader(tlsConn, snappystream.SkipVerifyChecksum) + + resp, _ = nsq.ReadResponse(compressConn) + frameType, data, _ = nsq.UnpackResponse(resp) + log.Printf("frameType: %d, data: %s", frameType, data) + assert.Equal(t, frameType, nsq.FrameTypeResponse) + assert.Equal(t, data, []byte("OK")) +} + func BenchmarkProtocolV2Exec(b *testing.B) { b.StopTimer() log.SetOutput(ioutil.Discard) diff --git a/test.sh b/test.sh index 0e80ea308..78d07fc11 100755 --- a/test.sh +++ b/test.sh @@ -21,8 +21,9 @@ popd >/dev/null pushd nsqd >/dev/null go build rm -f *.dat -echo "starting nsqd --data-path=/tmp --lookupd-tcp-address=127.0.0.1:4160 --tls-cert=./test/cert.pem --tls-key=./test/key.pem" -./nsqd --data-path=/tmp --lookupd-tcp-address=127.0.0.1:4160 --tls-cert=./test/cert.pem --tls-key=./test/key.pem >/dev/null 2>&1 & +cmd="./nsqd --data-path=/tmp --lookupd-tcp-address=127.0.0.1:4160 --tls-cert=./test/cert.pem --tls-key=./test/key.pem" +echo "starting $cmd" +$cmd >/dev/null 2>&1 & NSQD_PID=$! popd >/dev/null