Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

抽出readMessage函数,方便对该函数单独benchmark #15

Merged
merged 1 commit into from
Sep 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
298 changes: 148 additions & 150 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ type Conn struct {
bp bytespool.BytesPool // 实验某些特性加的字段

delayWrite
readHeadArray [enum.MaxFrameHeaderSize]byte
fragmentFramePayload []byte // 存放分片帧的缓冲区
bufioPayload []byte
fragmentFrameHeader *frame.FrameHeader
}

func setNoDelay(c net.Conn, noDelay bool) error {
Expand Down Expand Up @@ -132,10 +136,36 @@ func decode(payload []byte) ([]byte, error) {
return o.Bytes(), nil
}

func (c *Conn) ReadLoop() error {
func (c *Conn) ReadLoop() (err error) {
c.OnOpen(c)

return c.readLoop()
defer func() {
// c.OnClose(c, err)
c.Close()
if c.fr.IsInit() {
defer func() {
c.fr.Release()
c.fr.BufPtr()
}()
}
}()

if c.br != nil {
newSize := int(1024 * c.bufioMultipleTimesPayloadSize)
if c.br.Size() != newSize {
// TODO sync.Pool管理
(*bufio2.Reader2)(unsafe.Pointer(c.br)).ResetBuf(make([]byte, newSize))
}
// bufio 模式才会使用payload
c.bufioPayload = *bytespool.GetBytes(1024 + enum.MaxFrameHeaderSize)
}

for {
err = c.readMessage()
if err != nil {
return err
}
}
}

func (c *Conn) StartReadLoop() {
Expand All @@ -144,7 +174,7 @@ func (c *Conn) StartReadLoop() {
}()
}

func (c *Conn) readDataFromNet(headArray *[enum.MaxFrameHeaderSize]byte, payload *[]byte) (f frame.Frame, err error) {
func (c *Conn) readDataFromNet(headArray *[enum.MaxFrameHeaderSize]byte, bufioPayload *[]byte) (f frame.Frame, err error) {
if c.readTimeout > 0 {
err = c.c.SetReadDeadline(time.Now().Add(c.readTimeout))
if err != nil {
Expand All @@ -156,7 +186,7 @@ func (c *Conn) readDataFromNet(headArray *[enum.MaxFrameHeaderSize]byte, payload
if c.fr.IsInit() {
f, err = frame.ReadFrameFromWindows(&c.fr, headArray, c.windowsMultipleTimesPayloadSize)
} else {
f, err = frame.ReadFrameFromReader(c.br, headArray, payload)
f, err = frame.ReadFrameFromReader(c.br, headArray, bufioPayload)
}
if err != nil {
c.Callback.OnClose(c, err)
Expand All @@ -172,186 +202,154 @@ func (c *Conn) readDataFromNet(headArray *[enum.MaxFrameHeaderSize]byte, payload
}

// 读取websocket frame.Frame的循环
func (c *Conn) readLoop() error {
var f frame.Frame
var fragmentFrameHeader *frame.FrameHeader

defer c.Close()

var err error
var op opcode.Opcode

if c.fr.IsInit() {
defer func() {
c.fr.Release()
c.fr.BufPtr()
}()
func (c *Conn) readMessage() (err error) {
// 从网络读取数据
f, err := c.readDataFromNet(&c.readHeadArray, &c.bufioPayload)
if err != nil {
return err
}

var fragmentFrameBuf []byte
var headArray [enum.MaxFrameHeaderSize]byte

var payload []byte
if c.br != nil {
newSize := int(1024 * c.bufioMultipleTimesPayloadSize)
if c.br.Size() != newSize {
// TODO sync.Pool管理
(*bufio2.Reader2)(unsafe.Pointer(c.br)).ResetBuf(make([]byte, newSize))
}
// bufio 模式才会使用payload
payload = *bytespool.GetBytes(1024 + enum.MaxFrameHeaderSize)
op := f.Opcode
if c.fragmentFrameHeader != nil {
op = c.fragmentFrameHeader.Opcode
}

for {
rsv1 := f.GetRsv1()
// 检查Rsv1 rsv2 Rfd, errsv3
if rsv1 && c.failRsv1(op) || f.GetRsv2() || f.GetRsv3() {
err = fmt.Errorf("%w:Rsv1(%t) Rsv2(%t) rsv2(%t) compression:%t", ErrRsv123, rsv1, f.GetRsv2(), f.GetRsv3(), c.compression)
return c.writeErrAndOnClose(ProtocolError, err)
}

// 从网络读取数据
f, err = c.readDataFromNet(&headArray, &payload)
if err != nil {
return err
}
fin := f.GetFin()
if c.fragmentFrameHeader != nil && !f.Opcode.IsControl() {
if f.Opcode == 0 {
c.fragmentFramePayload = append(c.fragmentFramePayload, f.Payload...)

op = f.Opcode
if fragmentFrameHeader != nil {
op = fragmentFrameHeader.Opcode
}
// 分段的在这返回
if fin {
// 解压缩
if c.fragmentFrameHeader.GetRsv1() && c.decompression {
tempBuf, err := decode(c.fragmentFramePayload)
if err != nil {
return err
}
c.fragmentFramePayload = tempBuf
}
// 这里的check按道理应该放到f.Fin前面, 会更符合rfc的标准, 前提是c.utf8Check修改成流式解析
// TODO c.utf8Check 修改成流式解析
if c.fragmentFrameHeader.Opcode == opcode.Text && !c.utf8Check(c.fragmentFramePayload) {
c.Callback.OnClose(c, ErrTextNotUTF8)
return ErrTextNotUTF8
}

rsv1 := f.GetRsv1()
// 检查Rsv1 rsv2 Rsv3
if rsv1 && c.failRsv1(op) || f.GetRsv2() || f.GetRsv3() {
err = fmt.Errorf("%w:Rsv1(%t) Rsv2(%t) rsv2(%t) compression:%t", ErrRsv123, rsv1, f.GetRsv2(), f.GetRsv3(), c.compression)
return c.writeErrAndOnClose(ProtocolError, err)
c.Callback.OnMessage(c, c.fragmentFrameHeader.Opcode, c.fragmentFramePayload)
c.fragmentFramePayload = c.fragmentFramePayload[0:0]
c.fragmentFrameHeader = nil
}
return nil
}

fin := f.GetFin()
if fragmentFrameHeader != nil && !f.Opcode.IsControl() {
if f.Opcode == 0 {
fragmentFrameBuf = append(fragmentFrameBuf, f.Payload...)

// 分段的在这返回
if fin {
// 解压缩
if fragmentFrameHeader.GetRsv1() && c.decompression {
tempBuf, err := decode(fragmentFrameBuf)
if err != nil {
return err
}
fragmentFrameBuf = tempBuf
}
// 这里的check按道理应该放到f.Fin前面, 会更符合rfc的标准, 前提是c.utf8Check修改成流式解析
// TODO c.utf8Check 修改成流式解析
if fragmentFrameHeader.Opcode == opcode.Text && !c.utf8Check(fragmentFrameBuf) {
c.Callback.OnClose(c, ErrTextNotUTF8)
return ErrTextNotUTF8
}
c.writeErrAndOnClose(ProtocolError, ErrFrameOpcode)
return ErrFrameOpcode
}

c.Callback.OnMessage(c, fragmentFrameHeader.Opcode, fragmentFrameBuf)
fragmentFrameBuf = fragmentFrameBuf[0:0]
fragmentFrameHeader = nil
}
continue
if f.Opcode == opcode.Text || f.Opcode == opcode.Binary {
if !fin {
prevFrame := f.FrameHeader
// 第一次分段
if len(c.fragmentFramePayload) == 0 {
c.fragmentFramePayload = append(c.fragmentFramePayload, f.Payload...)
f.Payload = nil
}

c.writeErrAndOnClose(ProtocolError, ErrFrameOpcode)
return ErrFrameOpcode
// 让fragmentFrame的Payload指向readBuf, readBuf 原引用直接丢弃
c.fragmentFrameHeader = &prevFrame
return
}

if f.Opcode == opcode.Text || f.Opcode == opcode.Binary {
if !fin {
prevFrame := f.FrameHeader
// 第一次分段
if len(fragmentFrameBuf) == 0 {
fragmentFrameBuf = append(fragmentFrameBuf, f.Payload...)
f.Payload = nil
}

// 让fragmentFrame的Payload指向readBuf, readBuf 原引用直接丢弃
fragmentFrameHeader = &prevFrame
continue
if rsv1 && c.decompression {
// 不分段的解压缩
f.Payload, err = decode(f.Payload)
if err != nil {
return err
}
}

if rsv1 && c.decompression {
// 不分段的解压缩
f.Payload, err = decode(f.Payload)
if err != nil {
return err
}
if f.Opcode == opcode.Text {
if !c.utf8Check(f.Payload) {
c.c.Close()
c.Callback.OnClose(c, ErrTextNotUTF8)
return ErrTextNotUTF8
}
}

if f.Opcode == opcode.Text {
if !c.utf8Check(f.Payload) {
c.c.Close()
c.Callback.OnClose(c, ErrTextNotUTF8)
return ErrTextNotUTF8
}
}
c.Callback.OnMessage(c, f.Opcode, f.Payload)
return
}

c.Callback.OnMessage(c, f.Opcode, f.Payload)
continue
if f.Opcode == Close || f.Opcode == Ping || f.Opcode == Pong {
// 对方发的控制消息太大
if f.PayloadLen > maxControlFrameSize {
c.writeErrAndOnClose(ProtocolError, ErrMaxControlFrameSize)
return ErrMaxControlFrameSize
}
// Close, Ping, Pong 不能分片
if !fin {
c.writeErrAndOnClose(ProtocolError, ErrNOTBeFragmented)
return ErrNOTBeFragmented
}

if f.Opcode == Close || f.Opcode == Ping || f.Opcode == Pong {
// 对方发的控制消息太大
if f.PayloadLen > maxControlFrameSize {
c.writeErrAndOnClose(ProtocolError, ErrMaxControlFrameSize)
return ErrMaxControlFrameSize
}
// Close, Ping, Pong 不能分片
if !fin {
c.writeErrAndOnClose(ProtocolError, ErrNOTBeFragmented)
return ErrNOTBeFragmented
if f.Opcode == Close {
if len(f.Payload) == 0 {
return c.writeErrAndOnClose(NormalClosure, ErrClosePayloadTooSmall)
}

if f.Opcode == Close {
if len(f.Payload) == 0 {
return c.writeErrAndOnClose(NormalClosure, ErrClosePayloadTooSmall)
}

if len(f.Payload) < 2 {
return c.writeErrAndOnClose(ProtocolError, ErrClosePayloadTooSmall)
}

if !c.utf8Check(f.Payload[2:]) {
return c.writeErrAndOnClose(ProtocolError, ErrTextNotUTF8)
}
if len(f.Payload) < 2 {
return c.writeErrAndOnClose(ProtocolError, ErrClosePayloadTooSmall)
}

code := binary.BigEndian.Uint16(f.Payload)
if !validCode(code) {
return c.writeErrAndOnClose(ProtocolError, ErrCloseValue)
}
if !c.utf8Check(f.Payload[2:]) {
return c.writeErrAndOnClose(ProtocolError, ErrTextNotUTF8)
}

// 回敬一个close包
if err := c.WriteTimeout(Close, f.Payload, 2*time.Second); err != nil {
return err
}
code := binary.BigEndian.Uint16(f.Payload)
if !validCode(code) {
return c.writeErrAndOnClose(ProtocolError, ErrCloseValue)
}

err = bytesToCloseErrMsg(f.Payload)
c.Callback.OnClose(c, err)
// 回敬一个close包
if err := c.WriteTimeout(Close, f.Payload, 2*time.Second); err != nil {
return err
}

if f.Opcode == Ping {
// 回一个pong包
if c.replyPing {
if err := c.WriteTimeout(Pong, f.Payload, 2*time.Second); err != nil {
c.Callback.OnClose(c, err)
return err
}
c.Callback.OnMessage(c, f.Opcode, f.Payload)
continue
}
}
err = bytesToCloseErrMsg(f.Payload)
c.Callback.OnClose(c, err)
return err
}

if f.Opcode == Pong && c.ignorePong {
continue
if f.Opcode == Ping {
// 回一个pong包
if c.replyPing {
if err := c.WriteTimeout(Pong, f.Payload, 2*time.Second); err != nil {
c.Callback.OnClose(c, err)
return err
}
c.Callback.OnMessage(c, f.Opcode, f.Payload)
return
}
}

c.Callback.OnMessage(c, f.Opcode, nil)
continue
if f.Opcode == Pong && c.ignorePong {
return
}
// 检查Opcode
c.writeErrAndOnClose(ProtocolError, ErrOpcode)
return ErrOpcode

c.Callback.OnMessage(c, f.Opcode, nil)
return
}
// 检查Opcode
c.writeErrAndOnClose(ProtocolError, ErrOpcode)
return ErrOpcode
}

type wrapBuffer struct {
Expand Down
4 changes: 0 additions & 4 deletions server_handshake.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ var (
strWebSocketKey = "Sec-WebSocket-Key"
)

type ConnOption struct {
Config
}

func writeHeaderVal(w io.Writer, val []byte) (err error) {
if _, err = w.Write(val); err != nil {
return
Expand Down
4 changes: 4 additions & 0 deletions server_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ package quickws

type ServerOption func(*ConnOption)

type ConnOption struct {
Config
}

// 1.配置压缩和解压缩
func WithServerDecompressAndCompress() ServerOption {
return func(o *ConnOption) {
Expand Down
Loading