diff --git a/.gitignore b/.gitignore index b75c4ce..1e0860c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *swp +/cover.cov autobahn-testsuite autobahn-client-testsuite-darwin-arm64 autobahn-client-testsuite-linux-amd64 diff --git a/benchmark_mask_test.go b/benchmark_rand_test.go similarity index 84% rename from benchmark_mask_test.go rename to benchmark_rand_test.go index 70d385f..9ecc8a1 100644 --- a/benchmark_mask_test.go +++ b/benchmark_rand_test.go @@ -7,7 +7,6 @@ import ( func Benchmark_Rand_Uint32(t *testing.B) { for i := 0; i < t.N; i++ { - // newMask(maskValue[:]) _ = rand.Uint32() } } diff --git a/config.go b/config.go index 2121ef8..0545a42 100644 --- a/config.go +++ b/config.go @@ -28,6 +28,7 @@ type Config struct { decompression bool // 开启解压缩功能 compression bool // 开启压缩功能 ignorePong bool // 忽略pong消息 + disableBufioClearHack bool // 关闭bufio的clear hack优化 utf8Check func([]byte) bool // utf8检查 readTimeout time.Duration windowsMultipleTimesPayloadSize float32 // 设置几倍的payload大小 diff --git a/conn.go b/conn.go index a82e6e0..c86b9c7 100644 --- a/conn.go +++ b/conn.go @@ -364,38 +364,38 @@ func (c *Conn) WriteMessage(op Opcode, writeBuf []byte) (err error) { return frame.WriteFrame(&c.fw, c.c, writeBuf, rsv1, c.client, op, maskValue) } -// 这是一个不安全的方法, writeBuf的格式必须是 14个字节的空白长度+需要写的payload组成 -func (c *Conn) WriteMessageUnsafe(op Opcode, writeBuf []byte) (err error) { - // if op == opcode.Text { - // if !c.utf8Check(writeBuf) { - // return ErrTextNotUTF8 - // } - // } - - // rsv1 := c.compression && (op == opcode.Text || op == opcode.Binary) - // if rsv1 { - // var out wrapBuffer - // w := compressNoContextTakeover(&out, defaultCompressionLevel) - // if _, err = io.Copy(w, bytes.NewReader(writeBuf)); err != nil { - // return - // } - - // if err = w.Close(); err != nil { - // return - // } - // writeBuf = out.Bytes() - // } - - // // f.Opcode = op - // // f.PayloadLen = int64(len(writeBuf)) - // maskValue := uint32(0) - // if c.client { - // maskValue = rand.Uint32() - // } - - // return frame.WriteFrame(&c.fw, c.c, writeBuf, rsv1, c.client, op, maskValue) - return -} +// TODO 这是一个不安全的方法, writeBuf的格式必须是 14个字节的空白长度+需要写的payload组成 +// func (c *Conn) WriteMessageUnsafe(op Opcode, writeBuf []byte) (err error) { +// if op == opcode.Text { +// if !c.utf8Check(writeBuf) { +// return ErrTextNotUTF8 +// } +// } + +// rsv1 := c.compression && (op == opcode.Text || op == opcode.Binary) +// if rsv1 { +// var out wrapBuffer +// w := compressNoContextTakeover(&out, defaultCompressionLevel) +// if _, err = io.Copy(w, bytes.NewReader(writeBuf)); err != nil { +// return +// } + +// if err = w.Close(); err != nil { +// return +// } +// writeBuf = out.Bytes() +// } + +// // f.Opcode = op +// // f.PayloadLen = int64(len(writeBuf)) +// maskValue := uint32(0) +// if c.client { +// maskValue = rand.Uint32() +// } + +// return frame.WriteFrame(&c.fw, c.c, writeBuf, rsv1, c.client, op, maskValue) +// return +// } func (c *Conn) SetDeadline(t time.Time) error { return c.c.SetDeadline(t) diff --git a/parser_ext.go b/parser_ext.go index e62140c..2305534 100644 --- a/parser_ext.go +++ b/parser_ext.go @@ -2,7 +2,6 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// TODO 等重写http1.1 解析器, 再把这代码重写下 package quickws import ( diff --git a/quickws_server_test.go b/quickws_server_test.go new file mode 100644 index 0000000..8b01333 --- /dev/null +++ b/quickws_server_test.go @@ -0,0 +1,419 @@ +package quickws + +import ( + "fmt" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" +) + +// 测试服务端握手失败的情况 +func Test_Server_HandshakeFail(t *testing.T) { + // u := NewUpgrade() + t.Run("local config:case:method fail", func(t *testing.T) { + run := int32(0) + done := make(chan bool, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := Upgrade(w, r) + if err == nil { + t.Error("upgrade method fail") + } + atomic.AddInt32(&run, int32(1)) + done <- true + })) + + defer ts.Close() + + url := ts.URL + req, err := http.NewRequest("POST", url, nil) + if err != nil { + t.Error(err) + } + http.DefaultClient.Do(req) + select { + case <-done: + case <-time.After(100 * time.Millisecond): + } + if atomic.LoadInt32(&run) != 1 { + t.Error("not run server:method fail") + } + }) + + t.Run("global config:case:method fail", func(t *testing.T) { + run := int32(0) + upgrade := NewUpgrade() + done := make(chan bool, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := upgrade.Upgrade(w, r) + if err == nil { + t.Error("upgrade method fail") + } + atomic.AddInt32(&run, int32(1)) + done <- true + })) + + defer ts.Close() + + url := ts.URL + req, err := http.NewRequest("POST", url, nil) + if err != nil { + t.Error(err) + } + http.DefaultClient.Do(req) + select { + case <-done: + case <-time.After(100 * time.Millisecond): + } + if atomic.LoadInt32(&run) != 1 { + t.Error("not run server:method fail") + } + }) + + t.Run("local config:case:http proto fail", func(t *testing.T) { + run := int32(0) + done := make(chan bool, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := Upgrade(w, r) + if err == nil { + t.Error("upgrade http proto fail") + } + atomic.AddInt32(&run, int32(1)) + done <- true + })) + + url := strings.ReplaceAll(ts.URL, "http://", "") + defer ts.Close() + c, err := net.Dial("tcp", url) + if err != nil { + t.Error(err) + } + c.Write([]byte("GET / HTTP/1.0\r\nHost: localhost:8080\r\n\r\n")) + c.Close() + select { + case <-done: + case <-time.After(100 * time.Millisecond): + } + + if atomic.LoadInt32(&run) != 1 { + t.Error("not run server:http proto fail") + } + }) + + t.Run("global config:case:http proto fail", func(t *testing.T) { + run := int32(0) + upgrade := NewUpgrade() + done := make(chan bool, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := upgrade.Upgrade(w, r) + if err == nil { + t.Error("upgrade http proto fail") + } + atomic.AddInt32(&run, int32(1)) + done <- true + })) + + url := strings.ReplaceAll(ts.URL, "http://", "") + defer ts.Close() + c, err := net.Dial("tcp", url) + if err != nil { + t.Error(err) + } + c.Write([]byte("GET / HTTP/1.0\r\n\r\n")) + // c.Write([]byte("GET / HTTP/1.0\r\nHost: localhost:8080\r\n\r\n")) + c.Close() + + select { + case <-done: + case <-time.After(100 * time.Millisecond): + } + if atomic.LoadInt32(&run) != 1 { + t.Error("not run server:http proto fail") + } + }) + + t.Run("local config:case:host empty", func(t *testing.T) { + run := int32(0) + done := make(chan bool, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := Upgrade(w, r) + if err == nil { + t.Error("upgrade host fail") + } + atomic.AddInt32(&run, int32(1)) + done <- true + })) + + defer ts.Close() + + url := strings.ReplaceAll(ts.URL, "http://", "") + c, err := net.Dial("tcp", url) + if err != nil { + t.Error(err) + } + c.Write([]byte("GET / HTTP/1.1\r\nHost: \r\n\r\n")) + defer c.Close() + select { + case <-done: + case <-time.After(100 * time.Millisecond): + } + if atomic.LoadInt32(&run) != 1 { + t.Error("not run server:host empty") + } + }) + + t.Run("global config:case:upgrade fail", func(t *testing.T) { + run := int32(0) + upgrade := NewUpgrade() + done := make(chan bool, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := upgrade.Upgrade(w, r) + if err == nil { + t.Error("upgrade : upgrade field fail") + } + atomic.AddInt32(&run, int32(1)) + done <- true + })) + + url := strings.ReplaceAll(ts.URL, "http://", "") + defer ts.Close() + c, err := net.Dial("tcp", url) + if err != nil { + t.Error(err) + } + wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: xx\r\n\r\n", url)) + c.Write(wbuf) + c.Close() + + select { + case <-done: + case <-time.After(100 * time.Millisecond): + } + if atomic.LoadInt32(&run) != 1 { + t.Error("not run server:upgrade field fail") + } + }) + + t.Run("local config:case:upgrade fail", func(t *testing.T) { + run := int32(0) + done := make(chan bool, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := Upgrade(w, r) + if err == nil { + t.Error("upgrade : upgrade field fail") + } + atomic.AddInt32(&run, int32(1)) + done <- true + })) + + url := strings.ReplaceAll(ts.URL, "http://", "") + defer ts.Close() + c, err := net.Dial("tcp", url) + if err != nil { + t.Error(err) + } + wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: xx\r\n\r\n", url)) + c.Write(wbuf) + c.Close() + + select { + case <-done: + case <-time.After(100 * time.Millisecond): + } + if atomic.LoadInt32(&run) != 1 { + t.Error("not run server:upgrade field fail") + } + }) + + t.Run("global config:case:Connection fail", func(t *testing.T) { + run := int32(0) + upgrade := NewUpgrade() + done := make(chan bool, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := upgrade.Upgrade(w, r) + if err == nil { + t.Error("upgrade : Connection field fail") + } + atomic.AddInt32(&run, int32(1)) + done <- true + })) + + url := strings.ReplaceAll(ts.URL, "http://", "") + defer ts.Close() + c, err := net.Dial("tcp", url) + if err != nil { + t.Error(err) + } + wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: xx\r\n\r\n", url)) + c.Write(wbuf) + c.Close() + + select { + case <-done: + case <-time.After(100 * time.Millisecond): + } + if atomic.LoadInt32(&run) != 1 { + t.Error("not run server:Connection field fail") + } + }) + + t.Run("local config:case:Connection fail", func(t *testing.T) { + run := int32(0) + done := make(chan bool, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := Upgrade(w, r) + if err == nil { + t.Error("upgrade : Connection field fail") + } + atomic.AddInt32(&run, int32(1)) + done <- true + })) + + url := strings.ReplaceAll(ts.URL, "http://", "") + defer ts.Close() + c, err := net.Dial("tcp", url) + if err != nil { + t.Error(err) + } + wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: xx\r\n\r\n", url)) + c.Write(wbuf) + c.Close() + + select { + case <-done: + case <-time.After(100 * time.Millisecond): + } + if atomic.LoadInt32(&run) != 1 { + t.Error("not run server:Connection field fail") + } + }) + + t.Run("global config:case: Sec-WebSocket-Key fail", func(t *testing.T) { + run := int32(0) + upgrade := NewUpgrade() + done := make(chan bool, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := upgrade.Upgrade(w, r) + if err == nil { + t.Error("upgrade : Connection field fail") + } + atomic.AddInt32(&run, int32(1)) + done <- true + })) + + url := strings.ReplaceAll(ts.URL, "http://", "") + defer ts.Close() + c, err := net.Dial("tcp", url) + if err != nil { + t.Error(err) + } + wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", url)) + c.Write(wbuf) + c.Close() + + select { + case <-done: + case <-time.After(100 * time.Millisecond): + } + if atomic.LoadInt32(&run) != 1 { + t.Error("not run server:Connection field fail") + } + }) + + t.Run("local config:case: Sec-WebSocket-Key fail", func(t *testing.T) { + run := int32(0) + done := make(chan bool, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := Upgrade(w, r) + if err == nil { + t.Error("upgrade : Connection field fail") + } + atomic.AddInt32(&run, int32(1)) + done <- true + })) + + url := strings.ReplaceAll(ts.URL, "http://", "") + defer ts.Close() + c, err := net.Dial("tcp", url) + if err != nil { + t.Error(err) + } + wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\n\r\n", url)) + c.Write(wbuf) + c.Close() + + select { + case <-done: + case <-time.After(100 * time.Millisecond): + } + if atomic.LoadInt32(&run) != 1 { + t.Error("not run server:Connection field fail") + } + }) + + t.Run("global config:case: Sec-WebSocket-Version fail", func(t *testing.T) { + run := int32(0) + upgrade := NewUpgrade() + done := make(chan bool, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := upgrade.Upgrade(w, r) + if err == nil { + t.Error("upgrade : Connection field fail") + } + atomic.AddInt32(&run, int32(1)) + done <- true + })) + + url := strings.ReplaceAll(ts.URL, "http://", "") + defer ts.Close() + c, err := net.Dial("tcp", url) + if err != nil { + t.Error(err) + } + wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: key\r\n\r\n", url)) + c.Write(wbuf) + c.Close() + + select { + case <-done: + case <-time.After(100 * time.Millisecond): + } + if atomic.LoadInt32(&run) != 1 { + t.Error("not run server:Connection field fail") + } + }) + + t.Run("local config:case: Sec-WebSocket-Version fail", func(t *testing.T) { + run := int32(0) + done := make(chan bool, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := Upgrade(w, r) + if err == nil { + t.Error("upgrade : Connection field fail") + } + atomic.AddInt32(&run, int32(1)) + done <- true + })) + + url := strings.ReplaceAll(ts.URL, "http://", "") + defer ts.Close() + c, err := net.Dial("tcp", url) + if err != nil { + t.Error(err) + } + wbuf := []byte(fmt.Sprintf("GET / HTTP/1.1\r\nHost: %s\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: key\r\n\r\n", url)) + c.Write(wbuf) + c.Close() + + select { + case <-done: + case <-time.After(100 * time.Millisecond): + } + if atomic.LoadInt32(&run) != 1 { + t.Error("not run server:Connection field fail") + } + }) +} diff --git a/server.go b/server_handshake.go similarity index 100% rename from server.go rename to server_handshake.go diff --git a/server_options.go b/server_options.go index 282cea0..8760126 100644 --- a/server_options.go +++ b/server_options.go @@ -18,7 +18,7 @@ import "time" type ServerOption func(*ConnOption) -// 设置TCP_NODELAY +// 设置TCP_NODELAY 为false, 开启nagle算法 func WithServerTCPDelay() ServerOption { return func(o *ConnOption) { o.tcpNoDelay = false @@ -32,12 +32,14 @@ func WithServerDisableUTF8Check() ServerOption { } } +// 设置读超时时间 func WithServerReadTimeout(t time.Duration) ServerOption { return func(o *ConnOption) { o.readTimeout = t } } +// 配置回调函数 func WithServerCallback(cb Callback) ServerOption { return func(o *ConnOption) { o.Callback = cb @@ -84,9 +86,8 @@ func WithServerIgnorePong() ServerOption { // 只有解析方式是窗口的时候才有效 func WithWindowsMultipleTimesPayloadSize(mt float32) ServerOption { return func(o *ConnOption) { - // 如果mt < 1.0, 直接panic if mt < 1.0 { - panic("multipleTimesPayloadSize must >= 1.0") + mt = 1.0 } o.windowsMultipleTimesPayloadSize = mt } @@ -105,3 +106,10 @@ func WithBufioParseMode() ServerOption { o.parseMode = ParseModeBufio } } + +// 关闭bufio clear hack优化 +func WithDisableBufioClearHack() ServerOption { + return func(o *ConnOption) { + o.disableBufioClearHack = true + } +} diff --git a/server_test.go b/server_profile_test.go similarity index 100% rename from server_test.go rename to server_profile_test.go diff --git a/upgrade.go b/upgrade.go index 59dfad9..720b993 100644 --- a/upgrade.go +++ b/upgrade.go @@ -70,7 +70,9 @@ func upgradeInner(w http.ResponseWriter, r *http.Request, conf *Config) (c *Conn if conf.parseMode == ParseModeWindows { // 这里不需要rw,直接使用conn conn, rw, err = hi.Hijack() - bufio2.ClearReadWriter(rw) + if !conf.disableBufioClearHack { + bufio2.ClearReadWriter(rw) + } rsp.ClearRsp(w) rw = nil } else { diff --git a/utils.go b/utils.go index 18a2d1a..ca086b7 100644 --- a/utils.go +++ b/utils.go @@ -46,10 +46,6 @@ func StringToBytes(s string) (b []byte) { // return *(*string)(unsafe.Pointer(&b)) // } -func newMask(mask []byte) { - rand.Read(mask) -} - func secWebSocketAccept() string { // rfc规定是16字节 var key [16]byte