diff --git a/autobahn/autobahn-client-testsuite.go b/autobahn/autobahn-client-testsuite.go index d4f1dff..e7ff3e6 100644 --- a/autobahn/autobahn-client-testsuite.go +++ b/autobahn/autobahn-client-testsuite.go @@ -11,8 +11,8 @@ import ( // https://github.com/snapview/tokio-tungstenite/blob/master/examples/autobahn-client.rs const ( - // host = "ws://192.168.128.44:9003" - host = "ws://127.0.0.1:9003" + host = "ws://192.168.128.44:9003" + // host = "ws://127.0.0.1:9003" agent = "quickws" ) @@ -25,6 +25,7 @@ func (e *echoHandler) OnOpen(c *quickws.Conn) { } func (e *echoHandler) OnMessage(c *quickws.Conn, op quickws.Opcode, msg []byte) { + fmt.Println("OnMessage:", c, op, msg) if op == quickws.Text || op == quickws.Binary { // os.WriteFile("./debug.dat", msg, 0o644) if err := c.WriteTimeout(op, msg, 1*time.Minute); err != nil { @@ -81,7 +82,8 @@ func runTest(caseNo int) { done := make(chan struct{}) c, err := quickws.Dial(fmt.Sprintf("%s/runCase?case=%d&agent=%s", host, caseNo, agent), quickws.WithClientReplyPing(), - quickws.WithClientDecompression(), + // quickws.WithClientCompression(), + quickws.WithClientDecompressAndCompress(), quickws.WithClientCallback(&echoHandler{done: done}), ) if err != nil { diff --git a/benchmark_read_write_message_test.go b/benchmark_read_write_message_test.go index f0a85e2..96da20f 100644 --- a/benchmark_read_write_message_test.go +++ b/benchmark_read_write_message_test.go @@ -18,8 +18,14 @@ import ( "net" "testing" "time" + + "github.com/antlabs/wsutil/enum" + "github.com/antlabs/wsutil/frame" + "github.com/antlabs/wsutil/opcode" ) +var noMaskData = []byte{0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f} + // Read reads data from the connection. // Read can be made to time out and return an error after a fixed // time limit; see SetDeadline and SetReadDeadline. @@ -105,13 +111,13 @@ func Benchmark_WriteFrame(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { // - var f frame - f.fin = true + var f frame.Frame + f.Fin = true - f.opcode = Binary - f.payload = buf.Bytes() - f.payloadLen = int64(buf.Len()) - writeFrame(&buf, f) + f.Opcode = opcode.Binary + f.Payload = buf.Bytes() + f.PayloadLen = int64(buf.Len()) + frame.WriteFrame(&buf, f) buf.Reset() } } @@ -127,7 +133,7 @@ func Benchmark_WriteMessage(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - c.WriteMessage(Binary, buf) + c.WriteMessage(opcode.Binary, buf) buf2.Reset() } } @@ -145,7 +151,7 @@ func Benchmark_ReadMessage(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - c.WriteMessage(Binary, wbuf) + c.WriteMessage(opcode.Binary, wbuf) c.ReadLoop() buf2.Reset() } @@ -153,11 +159,11 @@ func Benchmark_ReadMessage(b *testing.B) { func Benchmark_ReadFrame(b *testing.B) { r := bytes.NewReader(noMaskData) - var headArray [maxFrameHeaderSize]byte + var headArray [enum.MaxFrameHeaderSize]byte for i := 0; i < b.N; i++ { r.Reset(noMaskData) - _, _, err := readHeader(r, &headArray) + _, _, err := frame.ReadHeader(r, &headArray) if err != nil { b.Fatal(err) } diff --git a/bytes_pool.go b/bytes_pool.go deleted file mode 100644 index 0345255..0000000 --- a/bytes_pool.go +++ /dev/null @@ -1,70 +0,0 @@ -package quickws - -import ( - "sync" -) - -const ( - page = 1024 - maxIndex = 64 -) - -func selectIndex(n int) int { - index := n / page - return index -} - -var pools = make([]sync.Pool, 0, maxIndex) - -var upgradeRespPool = sync.Pool{ - New: func() interface{} { - buf := make([]byte, 256) - return &buf - }, -} - -func init() { - for i := 1; i <= maxIndex; i++ { - j := i - pools = append(pools, sync.Pool{ - New: func() interface{} { - buf := make([]byte, j*page+maxFrameHeaderSize) - return &buf - }, - }) - } -} - -func getBytes(n int) (rv *[]byte) { - if n <= maxFrameHeaderSize { - return pools[0].Get().(*[]byte) - } - - index := selectIndex(n - maxFrameHeaderSize - 1) - if index >= len(pools) { - rv := make([]byte, n+maxFrameHeaderSize) - return &rv - } - - return pools[index].Get().(*[]byte) -} - -func putBytes(bytes *[]byte) { - if cap(*bytes) < maxFrameHeaderSize { - panic("putBytes: bytes is too small") - } - newLen := cap(*bytes) - maxFrameHeaderSize - 1 - index := selectIndex(newLen) - if index >= len(pools) { - return - } - pools[index].Put(bytes) -} - -func getUpgradeRespBytes() *[]byte { - return upgradeRespPool.Get().(*[]byte) -} - -func putUpgradeRespBytes(bytes *[]byte) { - upgradeRespPool.Put(bytes) -} diff --git a/bytes_pool_test.go b/bytes_pool_test.go deleted file mode 100644 index 3a8d3f1..0000000 --- a/bytes_pool_test.go +++ /dev/null @@ -1,43 +0,0 @@ -package quickws - -import ( - "testing" - "unsafe" - - "github.com/stretchr/testify/assert" -) - -func Test_Index(t *testing.T) { - for i := 0; i <= 1024+maxFrameHeaderSize; i++ { - i2 := i - if i2 >= maxFrameHeaderSize { - i2 -= (maxFrameHeaderSize + 1) - } - index := selectIndex(i2) - assert.Equal(t, index, 0) - } - - for i := 1024 + maxFrameHeaderSize + 1; i <= 2*1024+maxFrameHeaderSize; i++ { - i2 := i - i2 -= (maxFrameHeaderSize + 1) - index := selectIndex(i2) - assert.Equal(t, index, 1) - } - - for i := 1024*2 + maxFrameHeaderSize + 1; i <= 3*1024+maxFrameHeaderSize; i++ { - i2 := i - i2 -= (maxFrameHeaderSize + 1) - index := selectIndex(i2) - assert.Equal(t, index, 2) - } -} - -func Test_GetBytes_Address(t *testing.T) { - var m map[unsafe.Pointer]bool - for i := 0; i < 10; i++ { - p := getBytes(1) - if m[unsafe.Pointer(p)] { - t.Fatal("duplicate pointer") - } - } -} diff --git a/callback.go b/callback.go index 2d74e88..336c465 100644 --- a/callback.go +++ b/callback.go @@ -13,11 +13,13 @@ // limitations under the License. package quickws -type Callback interface { - OnOpen(*Conn) - OnMessage(*Conn, Opcode, []byte) - OnClose(*Conn, error) -} +type ( + Callback interface { + OnOpen(*Conn) + OnMessage(*Conn, Opcode, []byte) + OnClose(*Conn, error) + } +) type DefCallback struct{} diff --git a/client.go b/client.go index 8a855fa..2ae7f9b 100644 --- a/client.go +++ b/client.go @@ -23,6 +23,10 @@ import ( "net/url" "strings" "time" + + "github.com/antlabs/wsutil/bytespool" + "github.com/antlabs/wsutil/enum" + "github.com/antlabs/wsutil/fixedreader" ) var ( @@ -205,16 +209,16 @@ func (d *DialOption) Dial() (c *Conn, err error) { } // 处理下已经在bufio里面的数据,后面都是直接操作net.Conn,所以需要取出bufio里面已读取的数据 - var fr *fixedReader + var fr *fixedreader.FixedReader if brw.Reader.Buffered() > 0 { b, err := brw.Reader.Peek(brw.Reader.Buffered()) if err != nil { return nil, err } - b2 := getBytes(len(b) + maxFrameHeaderSize) + b2 := bytespool.GetBytes(len(b) + enum.MaxFrameHeaderSize) copy(*b2, b) - fr = newBuffer(conn, b2) - fr.w = len(b) + fr = fixedreader.NewFixedReader(conn, b2) + fr.W = len(b) } // fmt.Println(brw.Reader.Buffered()) return newConn(conn, true, d.config, fr), nil diff --git a/conn.go b/conn.go index 577fd55..ad6f132 100644 --- a/conn.go +++ b/conn.go @@ -22,6 +22,12 @@ import ( "net" "time" "unicode/utf8" + + "github.com/antlabs/wsutil/bytespool" + "github.com/antlabs/wsutil/enum" + "github.com/antlabs/wsutil/fixedreader" + "github.com/antlabs/wsutil/frame" + "github.com/antlabs/wsutil/opcode" ) const ( @@ -34,30 +40,30 @@ type Conn struct { c net.Conn client bool config - fr *fixedReader + fr *fixedreader.FixedReader } -func newConn(c net.Conn, client bool, conf config, fr *fixedReader) *Conn { +func newConn(c net.Conn, client bool, conf config, fr *fixedreader.FixedReader) *Conn { return &Conn{c: c, client: client, config: conf, fr: fr} } func (c *Conn) writeErrAndOnClose(code StatusCode, userErr error) error { defer c.Callback.OnClose(c, userErr) - if err := c.WriteTimeout(Close, statusCodeToBytes(code), 2*time.Second); err != nil { + if err := c.WriteTimeout(opcode.Close, statusCodeToBytes(code), 2*time.Second); err != nil { return err } return userErr } -func (c *Conn) failRsv1(op Opcode) bool { +func (c *Conn) failRsv1(op opcode.Opcode) bool { // 解压缩没有开启 if !c.decompression { return true } // 不是text和binary - if op != Text && op != Binary { + if op != opcode.Text && op != opcode.Binary { return true } @@ -81,7 +87,7 @@ func (c *Conn) ReadLoop() error { return c.readLoop() } -func (c *Conn) readDataFromNet(fixedBuf *fixedReader, headArray *[maxFrameHeaderSize]byte) (f frame, err error) { +func (c *Conn) readDataFromNet(fixedBuf *fixedreader.FixedReader, headArray *[enum.MaxFrameHeaderSize]byte) (f frame.Frame, err error) { if c.readTimeout > 0 { err = c.c.SetReadDeadline(time.Now().Add(c.readTimeout)) if err != nil { @@ -90,7 +96,7 @@ func (c *Conn) readDataFromNet(fixedBuf *fixedReader, headArray *[maxFrameHeader } } - f, err = readFrame(fixedBuf, headArray) + f, err = frame.ReadFrame(fixedBuf, headArray) if err != nil { c.Callback.OnClose(c, err) return @@ -104,27 +110,27 @@ func (c *Conn) readDataFromNet(fixedBuf *fixedReader, headArray *[maxFrameHeader return } -// 读取websocket frame的循环 +// 读取websocket frame.Frame的循环 func (c *Conn) readLoop() error { - var f frame - var fragmentFrameHeader *frameHeader + var f frame.Frame + var fragmentFrameHeader *frame.FrameHeader defer c.Close() var err error - var op Opcode + var op opcode.Opcode - var fixedBuf *fixedReader + var fixedBuf *fixedreader.FixedReader if c.fr != nil { fixedBuf = c.fr } else { // 默认最小1k + 14 - fixedBuf = newBuffer(c.c, getBytes(1024+maxFrameHeaderSize)) + fixedBuf = fixedreader.NewFixedReader(c.c, bytespool.GetBytes(1024+enum.MaxFrameHeaderSize)) } - defer fixedBuf.release() + defer fixedBuf.Release() var fragmentFrameBuf []byte - var headArray [maxFrameHeaderSize]byte + var headArray [enum.MaxFrameHeaderSize]byte for { @@ -134,39 +140,39 @@ func (c *Conn) readLoop() error { return err } - op = f.opcode + op = f.Opcode if fragmentFrameHeader != nil { - op = fragmentFrameHeader.opcode + op = fragmentFrameHeader.Opcode } - // 检查rsv1 rsv2 rsv3 - if f.rsv1 && c.failRsv1(op) || f.rsv2 || f.rsv3 { - err = fmt.Errorf("%w:rsv1(%t) rsv2(%t) rsv2(%t)", ErrRsv123, f.rsv1, f.rsv2, f.rsv3) + // 检查Rsv1 rsv2 Rsv3 + if f.Rsv1 && c.failRsv1(op) || f.Rsv2 || f.Rsv3 { + err = fmt.Errorf("%w:Rsv1(%t) Rsv2(%t) rsv2(%t) compression:%t", ErrRsv123, f.Rsv1, f.Rsv2, f.Rsv3, c.compression) return c.writeErrAndOnClose(ProtocolError, err) } - if fragmentFrameHeader != nil && !f.opcode.isControl() { - if f.opcode == 0 { - fragmentFrameBuf = append(fragmentFrameBuf, f.payload...) + if fragmentFrameHeader != nil && !f.Opcode.IsControl() { + if f.Opcode == 0 { + fragmentFrameBuf = append(fragmentFrameBuf, f.Payload...) // 分段的在这返回 - if f.fin { + if f.Fin { // 解压缩 - if fragmentFrameHeader.rsv1 && c.decompression { + if fragmentFrameHeader.Rsv1 && c.decompression { tmpeBuf, err := decode(fragmentFrameBuf) if err != nil { return err } fragmentFrameBuf = tmpeBuf } - // 这里的check按道理应该放到f.fin前面, 会更符合rfc的标准, 前提是utf8.Valid修改成流式解析 + // 这里的check按道理应该放到f.Fin前面, 会更符合rfc的标准, 前提是utf8.Valid修改成流式解析 // TODO utf8.Valid 修改成流式解析 - if fragmentFrameHeader.opcode == Text && !utf8.Valid(fragmentFrameBuf) { + if fragmentFrameHeader.Opcode == opcode.Text && !utf8.Valid(fragmentFrameBuf) { c.Callback.OnClose(c, ErrTextNotUTF8) return ErrTextNotUTF8 } - c.Callback.OnMessage(c, fragmentFrameHeader.opcode, fragmentFrameBuf) + c.Callback.OnMessage(c, fragmentFrameHeader.Opcode, fragmentFrameBuf) fragmentFrameBuf = fragmentFrameBuf[0:0] fragmentFrameHeader = nil } @@ -177,96 +183,96 @@ func (c *Conn) readLoop() error { return ErrFrameOpcode } - // 检查opcode - switch f.opcode { - case Text, Binary: - if !f.fin { - prevFrame := f.frameHeader + // 检查Opcode + switch f.Opcode { + case opcode.Text, opcode.Binary: + if !f.Fin { + prevFrame := f.FrameHeader // 第一次分段 if len(fragmentFrameBuf) == 0 { - fragmentFrameBuf = append(fragmentFrameBuf, f.payload...) - f.payload = nil + fragmentFrameBuf = append(fragmentFrameBuf, f.Payload...) + f.Payload = nil } - // 让fragmentFrame的payload指向readBuf, readBuf 原引用直接丢弃 + // 让fragmentFrame的Payload指向readBuf, readBuf 原引用直接丢弃 fragmentFrameHeader = &prevFrame continue } - if f.rsv1 && c.decompression { + if f.Rsv1 && c.decompression { // 不分段的解压缩 - f.payload, err = decode(f.payload) + f.Payload, err = decode(f.Payload) if err != nil { return err } } - if f.opcode == Text { - if !utf8.Valid(f.payload) { + if f.Opcode == opcode.Text { + if !utf8.Valid(f.Payload) { c.c.Close() c.Callback.OnClose(c, ErrTextNotUTF8) return ErrTextNotUTF8 } } - c.Callback.OnMessage(c, f.opcode, f.payload) + c.Callback.OnMessage(c, f.Opcode, f.Payload) case Close, Ping, Pong: // 对方发的控制消息太大 - if f.payloadLen > maxControlFrameSize { + if f.PayloadLen > maxControlFrameSize { c.writeErrAndOnClose(ProtocolError, ErrMaxControlFrameSize) return ErrMaxControlFrameSize } // Close, Ping, Pong 不能分片 - if !f.fin { + if !f.Fin { c.writeErrAndOnClose(ProtocolError, ErrNOTBeFragmented) return ErrNOTBeFragmented } - if f.opcode == Close { - if len(f.payload) == 0 { + if f.Opcode == Close { + if len(f.Payload) == 0 { return c.writeErrAndOnClose(NormalClosure, ErrClosePayloadTooSmall) } - if len(f.payload) < 2 { + if len(f.Payload) < 2 { return c.writeErrAndOnClose(ProtocolError, ErrClosePayloadTooSmall) } - if !utf8.Valid(f.payload[2:]) { + if !utf8.Valid(f.Payload[2:]) { return c.writeErrAndOnClose(ProtocolError, ErrTextNotUTF8) } - code := binary.BigEndian.Uint16(f.payload) + code := binary.BigEndian.Uint16(f.Payload) if !validCode(code) { return c.writeErrAndOnClose(ProtocolError, ErrCloseValue) } // 回敬一个close包 - if err := c.WriteTimeout(Close, f.payload, 2*time.Second); err != nil { + if err := c.WriteTimeout(Close, f.Payload, 2*time.Second); err != nil { return err } - err = bytesToCloseErrMsg(f.payload) + err = bytesToCloseErrMsg(f.Payload) c.Callback.OnClose(c, err) return err } - if f.opcode == Ping { + if f.Opcode == Ping { // 回一个pong包 if c.replyPing { - if err := c.WriteTimeout(Pong, f.payload, 2*time.Second); err != nil { + if err := c.WriteTimeout(Pong, f.Payload, 2*time.Second); err != nil { c.Callback.OnClose(c, err) return err } - c.Callback.OnMessage(c, f.opcode, f.payload) + c.Callback.OnMessage(c, f.Opcode, f.Payload) continue } } - if f.opcode == Pong && c.ignorePong { + if f.Opcode == Pong && c.ignorePong { continue } - c.Callback.OnMessage(c, f.opcode, nil) + c.Callback.OnMessage(c, f.Opcode, nil) default: c.writeErrAndOnClose(ProtocolError, ErrOpcode) return ErrOpcode @@ -284,17 +290,17 @@ func (w *wrapBuffer) Close() error { } func (c *Conn) WriteMessage(op Opcode, writeBuf []byte) (err error) { - var f frame + var f frame.Frame - if op == Text { + if op == opcode.Text { if !utf8.Valid(writeBuf) { return ErrTextNotUTF8 } } - f.fin = true - f.rsv1 = c.compression && (op == Text || op == Binary) - if f.rsv1 { + f.Fin = true + f.Rsv1 = c.compression && (op == opcode.Text || op == opcode.Binary) + if f.Rsv1 { var out wrapBuffer w := compressNoContextTakeover(&out, defaultCompressionLevel) if _, err = io.Copy(w, bytes.NewReader(writeBuf)); err != nil { @@ -307,15 +313,15 @@ func (c *Conn) WriteMessage(op Opcode, writeBuf []byte) (err error) { writeBuf = out.Bytes() } - f.opcode = op - f.payload = writeBuf - f.payloadLen = int64(len(writeBuf)) + f.Opcode = op + f.Payload = writeBuf + f.PayloadLen = int64(len(writeBuf)) if c.client { - f.mask = true - newMask(f.maskValue[:]) + f.Mask = true + newMask(f.MaskValue[:]) } - return writeFrame(c.c, f) + return frame.WriteFrame(c.c, f) } func (c *Conn) SetDeadline(t time.Time) error { @@ -333,7 +339,7 @@ func (c *Conn) WriteTimeout(op Opcode, data []byte, t time.Duration) (err error) func (c *Conn) WriteCloseTimeout(sc StatusCode, t time.Duration) (err error) { buf := statusCodeToBytes(sc) - return c.WriteTimeout(Close, buf, t) + return c.WriteTimeout(opcode.Close, buf, t) } func (c *Conn) Close() error { diff --git a/fixedreader.go b/fixedreader.go deleted file mode 100644 index c2807c8..0000000 --- a/fixedreader.go +++ /dev/null @@ -1,145 +0,0 @@ -package quickws - -import ( - "errors" - "io" -) - -var errNegativeRead = errors.New("bufio: reader returned negative count from Read") - -// 固定大小的fixedReader, 所有的内存都是提前分配好的 -// 标准库的bufio.Reader不能传递一个固定大小的buf, 导致控制力度会差点 -type fixedReader struct { - buf []byte - p *[]byte - rd io.Reader // reader provided by the client - r, w int // buf read and write positions - err error -} - -// newBuffer returns a new Buffer whose buffer has the specified size. -func newBuffer(r io.Reader, buf *[]byte) *fixedReader { - return &fixedReader{ - rd: r, - buf: *buf, - p: buf, - } -} - -func (b *fixedReader) release() error { - if b.p != nil { - putBytes(b.p) - b.buf = nil - b.p = nil - } - return nil -} - -func (b *fixedReader) readErr() error { - err := b.err - b.err = nil - return err -} - -// 将缓存区重置为一个新的buf -func (b *fixedReader) reset(buf *[]byte) { - if len(*buf) < len(b.buf[b.r:b.w]) { - panic("new buf size is too small") - } - - copy(*buf, b.buf[b.r:b.w]) - b.w -= b.r - b.r = 0 - b.p = buf - b.buf = *buf -} - -// 返回底层[]byte的长度 -func (b *fixedReader) Len() int { - return len(b.buf) -} - -func (b *fixedReader) ptr() *[]byte { - return b.p -} - -func (b *fixedReader) bytes() []byte { - return b.buf -} - -// 返回剩余可写的缓存区大小 -func (b *fixedReader) writeCap() int { - return len(b.buf[b.w:]) -} - -// 返回剩余可用的缓存区大小 -func (b *fixedReader) available() int64 { - return int64(len(b.buf[b.w:]) + b.r) -} - -// 左移缓存区 -func (b *fixedReader) leftMove() { - if b.r == 0 { - return - } - copy(b.buf, b.buf[b.r:b.w]) - b.w -= b.r - b.r = 0 -} - -// 返回可写的缓存区 -func (b *fixedReader) writeCapBytes() []byte { - return b.buf[b.w:] -} - -func (b *fixedReader) cloneAvailable() *fixedReader { - return &fixedReader{rd: b.rd, buf: b.buf[b.w:]} -} - -func (b *fixedReader) Buffered() int { return b.w - b.r } - -// 这和一般read接口中不一样 -// 传入的p 一定会满足这个大小 -func (b *fixedReader) Read(p []byte) (n int, err error) { - if cap(b.buf) < cap(p) { - panic("fixedReader.Reader buf size is too small: cap(b.buf) < cap(p)") - } - - n = len(p) - if n == 0 { - if b.Buffered() > 0 { - return 0, nil - } - return 0, b.readErr() - } - - var n1 int - for { - - if b.r == b.w || len(b.buf[b.r:b.w]) < len(p) { - if b.err != nil { - return 0, b.readErr() - } - if b.r == b.w { - b.r = 0 - b.w = 0 - } - n1, b.err = b.rd.Read(b.buf[b.w:]) - if n1 < 0 { - panic(errNegativeRead) - } - if n1 == 0 { - return 0, b.readErr() - } - b.w += n1 - } - - if len(b.buf[b.r:b.w]) < len(p) { - continue - } - n1 = copy(p, b.buf[b.r:b.w]) - b.r += n1 - - return n, nil - } -} diff --git a/fixedreader_test.go b/fixedreader_test.go deleted file mode 100644 index 5b27d0d..0000000 --- a/fixedreader_test.go +++ /dev/null @@ -1,229 +0,0 @@ -package quickws - -import ( - "bytes" - "fmt" - "io" - "testing" - - "github.com/stretchr/testify/assert" -) - -// func splitString(s string, chunkSize int) []string { -// var chunks []string -// for i := 0; i < len(s); i += chunkSize { -// end := i + chunkSize -// if end > len(s) { -// end = len(s) -// } -// chunks = append(chunks, s[i:end]) -// } -// return chunks -// } - -func Test_Reader_Small(t *testing.T) { - var out bytes.Buffer - - tmp := append([]byte(nil), testTextMessage64kb...) - err := writeMessgae(&out, Text, tmp, true) - // hexString := hex.EncodeToString(out.Bytes()) - // // 在每两个字符之间插入空格 - // spacedHexString := strings.Join(splitString(hexString, 2), ", ") - // fmt.Printf("header: %+v\n", spacedHexString[:100]) - assert.NoError(t, err) - - r := newBuffer(&out, getBytes(1024+maxFrameHeaderSize)) - - var headArray [14]byte - f, err := readFrame(r, &headArray) - assert.NoError(t, err) - // err = os.WriteFile("./test_reader.dat", f.payload, 0o644) - - assert.NoError(t, err) - assert.Equal(t, f.payload, testTextMessage64kb) -} - -func Test_Reader_WriteMulti_ReadOne(t *testing.T) { - var out bytes.Buffer - - for i := 1024 * 63; i <= 1024*63+1; i++ { - need := make([]byte, 0, i) - got := make([]byte, 0, i) - for j := 0; j < i; j++ { - need = append(need, byte(j)) - got = append(got, byte(j)) - } - - for j := 0; j < 1; j++ { - err := writeMessgae(&out, Text, need, true) - assert.NoError(t, err) - err = writeMessgae(&out, Text, need, true) - assert.NoError(t, err) - } - fmt.Printf("i = %d, need: len(%d), write.size:%d\n", i, len(need), out.Len()) - - b := getBytes(1024 + maxFrameHeaderSize) - r := newBuffer(&out, b) - var headArray [14]byte - for j := 0; j < 2; j++ { - - f, err := readFrame(r, &headArray) - assert.NoError(t, err) - - assert.NoError(t, err) - // TODO - if j == 0 { - continue - } - if !bytes.Equal(f.payload, got) { - t.Fatalf("bad test index:%d\n", i) - return - } - // assert.Equal(t, f.payload, got, fmt.Sprintf("index:%d", i)) - if err != nil { - return - } - } - putBytes(r.ptr()) - out.Reset() - } -} - -// 测试只写一次数据包,但是分多次读取 -func Test_Reader_WriteOne_ReadMulti(t *testing.T) { - var out bytes.Buffer - - var headArray [14]byte - for i := 1031; i <= 1024*64; i++ { - // for i := 2046; i <= 2048; i++ { - need := make([]byte, 0, i) - got := make([]byte, 0, i) - for j := 0; j < i; j++ { - need = append(need, byte(j)) - got = append(got, byte(j)) - } - - err := writeMessgae(&out, Text, need, true) - assert.NoError(t, err) - - b := getBytes(1024 + maxFrameHeaderSize) - r := newBuffer(&out, b) - - f, err := readFrame(r, &headArray) - assert.NoError(t, err) - if err != nil { - return - } - - assert.NoError(t, err) - // TODO - if i == 0 { - continue - } - assert.Equal(t, f.payload, got, fmt.Sprintf("index:%d", i)) - putBytes(r.ptr()) - out.Reset() - } -} - -func Test_Reset(t *testing.T) { - var out bytes.Buffer - out.Write([]byte("1234")) - r := newBuffer(&out, getBytes(1024+maxFrameHeaderSize)) - - small := make([]byte, 2) - - r.Read(small) - r.reset(getBytes(1024*2 + maxFrameHeaderSize)) - assert.Equal(t, r.bytes()[:2], []byte("34")) - // assert.Equal(t, r.free()[:2], []byte{0, 0}) -} - -func Test_Reader_WriteMulti_ReadOne_64512(t *testing.T) { - var out bytes.Buffer - - for i := 64512; i <= 64512; i++ { - need := make([]byte, 0, i) - got := make([]byte, 0, i) - for j := 0; j < i; j++ { - need = append(need, byte(j)) - got = append(got, byte(j)) - } - - for j := 0; j < 1; j++ { - err := writeMessgae(&out, Text, need, true) - assert.NoError(t, err) - err = writeMessgae(&out, Text, need, true) - assert.NoError(t, err) - } - fmt.Printf("i = %d, need: len(%d), write.size:%d\n", i, len(need), out.Len()) - - b := getBytes(1024 + maxFrameHeaderSize) - r := newBuffer(&out, b) - var headArray [14]byte - for j := 0; j < 2; j++ { - - f, err := readFrame(r, &headArray) - assert.NoError(t, err) - - assert.NoError(t, err) - // TODO - if j == 0 { - continue - } - if !bytes.Equal(f.payload, got) { - t.Fatalf("bad test index:%d\n", i) - return - } - // assert.Equal(t, f.payload, got, fmt.Sprintf("index:%d", i)) - if err != nil { - return - } - } - putBytes(r.ptr()) - out.Reset() - } -} - -func Test_Reader_WriteMulti_ReadOne_65536(t *testing.T) { - var out bytes.Buffer - - var headArray [14]byte - for i := 65536; i <= 64512; i++ { - need := make([]byte, 0, i) - got := make([]byte, 0, i) - for j := 0; j < i; j++ { - need = append(need, byte(j)) - got = append(got, byte(j)) - } - - for j := 0; j < 1; j++ { - err := writeMessgae(&out, Text, need, true) - assert.NoError(t, err) - err = writeMessgae(&out, Text, need, true) - assert.NoError(t, err) - } - fmt.Printf("i = %d, need: len(%d), write.size:%d\n", i, len(need), out.Len()) - - b := getBytes(1024 + maxFrameHeaderSize) - r := newBuffer(&out, b) - for j := 0; j < 2; j++ { - - f, err := readFrame(r, &headArray) - if err != io.EOF { - assert.NoError(t, err) - } - - if !bytes.Equal(f.payload, got) { - t.Fatalf("bad test index:%d\n", i) - return - } - // assert.Equal(t, f.payload, got, fmt.Sprintf("index:%d", i)) - if err != nil { - return - } - } - putBytes(r.ptr()) - out.Reset() - } -} diff --git a/fixedwriter.go b/fixedwriter.go deleted file mode 100644 index 015573d..0000000 --- a/fixedwriter.go +++ /dev/null @@ -1,25 +0,0 @@ -package quickws - -import "fmt" - -type fixedWriter struct { - buf []byte - w int -} - -func (fw *fixedWriter) Write(p []byte) (n int, err error) { - if len(fw.buf[fw.w:]) < len(p) { - panic(fmt.Sprintf("fixedWriter: buf is too small: %d:%d < %d", len(fw.buf[fw.w:]), cap(fw.buf), cap(p))) - } - n = copy(fw.buf[fw.w:], p) - fw.w += n - return n, nil -} - -func (fw *fixedWriter) Len() int { - return fw.w -} - -func (fw *fixedWriter) Bytes() []byte { - return fw.buf[:fw.w] -} diff --git a/fixedwriter_test.go b/fixedwriter_test.go deleted file mode 100644 index 515f767..0000000 --- a/fixedwriter_test.go +++ /dev/null @@ -1,21 +0,0 @@ -package quickws - -import "testing" - -func Test_FixedWriter(t *testing.T) { - fw := &fixedWriter{buf: make([]byte, 1024)} - n, err := fw.Write([]byte("hello")) - if err != nil { - t.Errorf("fw.Write() = %v, want nil", err) - } - if n != 5 { - t.Errorf("fw.Write() = %d, want 5", n) - } - - if fw.Len() != 5 { - t.Errorf("fw.Len() = %d, want 5", fw.Len()) - } - if string(fw.Bytes()) != "hello" { - t.Errorf("fw.Bytes() = %s, want hello", fw.Bytes()) - } -} diff --git a/frame.go b/frame.go deleted file mode 100644 index e5cc9cb..0000000 --- a/frame.go +++ /dev/null @@ -1,278 +0,0 @@ -// Copyright 2021-2023 antlabs. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package quickws - -import ( - "encoding/binary" - "errors" - "io" - "math" - - "github.com/antlabs/wsutil/mask" -) - -const ( - // 根据5.2描述, 满打满算, 最多14字节 - maxFrameHeaderSize = 14 -) - -var ErrFramePayloadLength = errors.New("error frame payload length") - -type frameHeader struct { - payloadLen int64 - opcode Opcode - maskValue [4]byte - rsv1 bool - rsv2 bool - rsv3 bool - mask bool - fin bool -} - -type frame struct { - frameHeader - payload []byte -} - -func readFrame(r *fixedReader, headArray *[maxFrameHeaderSize]byte) (f frame, err error) { - // 如果剩余可写缓存区放不下一个frame header, 就把数据往前移动 - // 所有的的buf分配都是paydload + frame head 的长度, 挪完之后,肯定是能放下一个frame header的 - if r.Len()-r.r < maxFrameHeaderSize { - r.leftMove() - if r.Len() < maxFrameHeaderSize { - panic("readFrame r.Len() < maxFrameHeaderSize") - } - } - - h, _, err := readHeader(r, headArray) - if err != nil { - return f, err - } - - // 如果缓存区不够, 重新分配 - - // h.payloadLen 是要读取body的总数据 - // h.w - h.r 是已经读取未处理的数据 - // 还需要读取的数据等于 h.payloadLen - (h.w - h.r) - - // 已读取未处理的数据 - readUnhandle := int64(r.w - r.r) - // 情况 1,需要读的长度 > 剩余可用空间(未写的+已经被读取走的) - if h.payloadLen-readUnhandle > r.available() { - // 1.取得旧的buf - oldBuf := r.ptr() - // 2.获取新的buf - newBuf := getBytes(int(h.payloadLen) + maxFrameHeaderSize) - // 3.重置缓存区 - r.reset(newBuf) - // 4.将旧的buf放回池子里 - putBytes(oldBuf) - - // 情况 2。 空间是够的,需要挪一挪, 把已经读过的覆盖掉 - } else if h.payloadLen-readUnhandle > int64(r.writeCap()) { - r.leftMove() - } - - // 返回可写的缓存区 - payload := r.writeCapBytes() - // 前面的reset已经保证了,buffer的大小是够的 - needRead := h.payloadLen - readUnhandle - - if needRead > 0 { - // payload是一块干净可写的空间,使用needRead框下范围 - payload = payload[:needRead] - // 新建一对新的r w指向尾部的内存区域 - right := r.cloneAvailable() - if _, err = io.ReadFull(right, payload); err != nil { - return f, err - } - - // right 也有可能超读, 直接加上payload的长度,会把超读的数据给丢了 - // 为什么会发生超读呢,right持的buf 会 >= payload的长度 - r.w += right.w - } - - f.payload = r.bytes()[r.r : r.r+int(h.payloadLen)] - f.frameHeader = h - r.r += int(h.payloadLen) - if h.mask { - key := binary.LittleEndian.Uint32(h.maskValue[:]) - mask.Mask(f.payload, key) - } - - return f, nil -} - -func readHeader(r io.Reader, headArray *[maxFrameHeaderSize]byte) (h frameHeader, size int, err error) { - // var headArray [maxFrameHeaderSize]byte - head := (*headArray)[:2] - - n, err := io.ReadFull(r, head) - if err != nil { - return - } - if n != 2 { - err = io.ErrUnexpectedEOF - return - } - size = 2 - - h.fin = head[0]&(1<<7) > 0 - h.rsv1 = head[0]&(1<<6) > 0 - h.rsv2 = head[0]&(1<<5) > 0 - h.rsv3 = head[0]&(1<<4) > 0 - h.opcode = Opcode(head[0] & 0xF) - - have := 0 - h.mask = head[1]&(1<<7) > 0 - if h.mask { - have += 4 - size += 4 - } - - h.payloadLen = int64(head[1] & 0x7F) - - switch { - // 长度 - case h.payloadLen >= 0 && h.payloadLen <= 125: - if h.payloadLen == 0 && !h.mask { - return - } - case h.payloadLen == 126: - // 2字节长度 - have += 2 - size += 2 - case h.payloadLen == 127: - // 8字节长度 - have += 8 - size += 8 - default: - // 预期之外的, 直接报错 - return h, 0, ErrFramePayloadLength - } - - head = head[:have] - _, err = io.ReadFull(r, head) - if err != nil { - return - } - - switch h.payloadLen { - case 126: - h.payloadLen = int64(binary.BigEndian.Uint16(head[:2])) - head = head[2:] - case 127: - h.payloadLen = int64(binary.BigEndian.Uint64(head[:8])) - head = head[8:] - } - - if h.mask { - copy(h.maskValue[:], head) - } - - return -} - -// https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 -// (the most significant bit MUST be 0) -func writeHeader(w io.Writer, h frameHeader) (err error) { - var head [maxFrameHeaderSize]byte - - if h.fin { - head[0] |= 1 << 7 - } - - if h.rsv1 { - head[0] |= 1 << 6 - } - - if h.rsv2 { - head[0] |= 1 << 5 - } - - if h.rsv3 { - head[0] |= 1 << 5 - } - - head[0] |= byte(h.opcode & 0xF) - - have := 2 - switch { - case h.payloadLen <= 125: - head[1] = byte(h.payloadLen) - case h.payloadLen <= math.MaxUint16: - head[1] = 126 - binary.BigEndian.PutUint16(head[2:], uint16(h.payloadLen)) - have += 2 // 2前 - default: - head[1] = 127 - binary.BigEndian.PutUint64(head[2:], uint64(h.payloadLen)) - have += 8 - } - - if h.mask { - head[1] |= 1 << 7 - have += copy(head[have:], h.maskValue[:]) - } - - _, err = w.Write(head[:have]) - return err -} - -func writeMessgae(w io.Writer, op Opcode, writeBuf []byte, isClient bool) (err error) { - var f frame - f.fin = true - f.opcode = op - f.payload = writeBuf - f.payloadLen = int64(len(writeBuf)) - defer func() { - f.payload = nil - }() - if isClient { - f.mask = true - newMask(f.maskValue[:]) - } - - return writeFrame(w, f) -} - -func writeFrame(w io.Writer, f frame) (err error) { - buf := getBytes(len(f.payload) + maxFrameHeaderSize) - - tmpWriter := fixedWriter{buf: *buf} - var ws io.Writer = &tmpWriter - - defer func() { - tmpWriter.buf = nil - putBytes(buf) - }() - if err = writeHeader(ws, f.frameHeader); err != nil { - return - } - - wIndex := tmpWriter.Len() - _, err = ws.Write(f.payload) - if err != nil { - return - } - if f.mask { - key := binary.LittleEndian.Uint32(f.maskValue[:]) - mask.Mask(tmpWriter.Bytes()[wIndex:], key) - } - - // fmt.Printf("writeFrame %#v\n", tmpWriter.Bytes()) - _, err = w.Write(tmpWriter.Bytes()) - return err -} diff --git a/frame_test.go b/frame_test.go deleted file mode 100644 index 9d8f055..0000000 --- a/frame_test.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2021-2023 antlabs. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -package quickws - -import ( - "bytes" - "fmt" - "io" - "testing" - - "github.com/stretchr/testify/assert" -) - -var ( - noMaskData = []byte{0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f} - haveMaskData = []byte{0x81, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58} -) - -func Test_Frame_Read_Size(t *testing.T) { - var out bytes.Buffer - err := writeMessgae(&out, Text, nil, true) - assert.NoError(t, err) - var headArray [14]byte - outLen := out.Len() - _, size, err := readHeader(&out, &headArray) - assert.NoError(t, err) - assert.Equal(t, size, outLen) - fmt.Printf("%d:%d\n", size, outLen) -} - -func Test_Frame_Read_NoMask(t *testing.T) { - r := bytes.NewReader(noMaskData) - - var headArray [14]byte - h, _, err := readHeader(r, &headArray) - assert.NoError(t, err) - all, err := io.ReadAll(r) - assert.NoError(t, err) - - // fmt.Printf("opcode:%d", h.opcode) - assert.Equal(t, string(all), "Hello") - assert.Equal(t, h.payloadLen, int64(len("Hello"))) -} - -func Test_Frame_Mask_Read_And_Write(t *testing.T) { - r := bytes.NewReader(haveMaskData) - - buf := make([]byte, 512) - rr := newBuffer(r, &buf) - var headArray [14]byte - f, err := readFrame(rr, &headArray) - assert.NoError(t, err) - - assert.Equal(t, string(f.payload[:f.payloadLen]), "Hello") - - var w bytes.Buffer - assert.NoError(t, writeFrame(&w, f)) - assert.Equal(t, w.Bytes(), haveMaskData) -} - -func Test_Frame_Write_NoMask(t *testing.T) { - // br := bytes.NewReader([]byte{0x81, 0x05, 0x48, 0x65, 0x6c, 0x6c, 0x6f}) - - var w bytes.Buffer - var h frameHeader - h.payloadLen = int64(5) - h.opcode = 1 - h.fin = true - assert.NoError(t, writeHeader(&w, h)) - _, err := w.WriteString("Hello") - - assert.NoError(t, err) - assert.Equal(t, w.Bytes(), noMaskData) -} diff --git a/go.mod b/go.mod index 1a58c41..36deae3 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,12 @@ module github.com/antlabs/quickws go 1.20 require ( - github.com/antlabs/wsutil v0.0.1 - github.com/stretchr/testify v1.7.0 + github.com/antlabs/wsutil v0.0.2 + github.com/stretchr/testify v1.8.4 ) require ( - github.com/davecgh/go-spew v1.1.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index de03c0b..e88ce71 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,12 @@ -github.com/antlabs/wsutil v0.0.1 h1:Fhw47NrEDP/8PIAcFNmPS60jIR54TungCFuRq/aEzbM= -github.com/antlabs/wsutil v0.0.1/go.mod h1:7ec5eUM7nmKW+Oi6F1I58iatOeL9k+yIsfOh1zh910g= -github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/antlabs/wsutil v0.0.2 h1:qOHEDo+m/VZcRV/pd954lSprY/cddLhwUVQKk6JWPv4= +github.com/antlabs/wsutil v0.0.2/go.mod h1:qL4K9yILLmkvukmcHkruNAZoPAf7CSbnxm8DA6Y4hzY= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= -github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/message_test.go b/message_test.go index 46ee077..4859774 100644 --- a/message_test.go +++ b/message_test.go @@ -11,6 +11,7 @@ import ( "testing" "time" + "github.com/antlabs/wsutil/opcode" "github.com/stretchr/testify/assert" ) @@ -32,7 +33,7 @@ type testMessageHandler struct { output bool } -func (t *testMessageHandler) OnMessage(c *Conn, op Opcode, msg []byte) { +func (t *testMessageHandler) OnMessage(c *Conn, op opcode.Opcode, msg []byte) { need := append([]byte(nil), t.need...) atomic.StoreInt32(&t.callbed, 1) if t.count == 0 { diff --git a/opcode.go b/opcode.go index 9305905..e498c24 100644 --- a/opcode.go +++ b/opcode.go @@ -14,46 +14,23 @@ package quickws -type Opcode uint8 +import "github.com/antlabs/wsutil/opcode" + +type Opcode = opcode.Opcode const ( - Continuation Opcode = iota - Text - Binary + Continuation = opcode.Continuation + Text = opcode.Text + Binary = opcode.Binary // 3 - 7保留 _ // 3 _ _ // 5 _ - _ // 7 - Close - Ping - Pong + _ // 7 + Close = opcode.Close + Ping = opcode.Ping + Pong = opcode.Pong ) var ErrClose = "websocket" - -func (c Opcode) String() string { - switch { - case c >= 3 && c <= 7: - return "control" - case c == Text: - return "text" - case c == Binary: - return "binary" - case c == Close: - return "close" - case c == Ping: - return "ping" - case c == Pong: - return "pong" - case c == Continuation: - return "continuation" - default: - return "unknown" - } -} - -func (c Opcode) isControl() bool { - return (c & (1 << 3)) > 0 -} diff --git a/opcode_test.go b/opcode_test.go deleted file mode 100644 index 6a0e817..0000000 --- a/opcode_test.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2021-2023 antlabs. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package quickws - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func Test_Control(t *testing.T) { - assert.False(t, Text.isControl()) - assert.False(t, Binary.isControl()) - assert.True(t, Ping.isControl()) - assert.True(t, Pong.isControl()) - assert.True(t, Close.isControl()) -} diff --git a/server.go b/server.go index ec4adbc..c7be4e5 100644 --- a/server.go +++ b/server.go @@ -21,6 +21,8 @@ import ( "io" "net/http" "strings" + + "github.com/antlabs/wsutil/bytespool" ) var ( @@ -63,11 +65,11 @@ func Upgrade(w http.ResponseWriter, r *http.Request, opts ...OptionServer) (c *C conf.decompression = needDecompression(r.Header) } - buf := getUpgradeRespBytes() + buf := bytespool.GetUpgradeRespBytes() tmpWriter := bytes.NewBuffer((*buf)[:0]) defer func() { - putUpgradeRespBytes(buf) + bytespool.PutUpgradeRespBytes(buf) tmpWriter = nil }() if err = prepareWriteResponse(r, tmpWriter, conf.config); err != nil { diff --git a/utils.go b/utils.go index 1e9eb40..18a2d1a 100644 --- a/utils.go +++ b/utils.go @@ -100,7 +100,7 @@ func maybeCompressionDecompression(header http.Header) bool { } _, s := ext["server_no_context_takeover"] _, c := ext["client_no_context_takeover"] - return s && c + return s || c } return false