Skip to content

Commit

Permalink
websocket: simplify cache append and concurrent free logic
Browse files Browse the repository at this point in the history
  • Loading branch information
lesismal committed Jun 24, 2024
1 parent c00b749 commit 985998b
Showing 1 changed file with 73 additions and 96 deletions.
169 changes: 73 additions & 96 deletions nbhttp/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,16 @@ ErrExit:
c.Close()
}

func (c *Conn) nextFrame(data []byte) ([]byte, MessageType, []byte, bool, bool, bool, error) {
func (c *Conn) nextFrame() (int, MessageType, []byte, bool, bool, bool, error) {
var (
opcode MessageType
body []byte
ok, fin, res1, res2, res3 bool
err error
data = c.bytesCached
l = int64(len(data))
headLen = int64(2)
total int64
)
if l >= 2 {
opcode = MessageType(data[0] & 0xF)
Expand All @@ -297,38 +299,41 @@ func (c *Conn) nextFrame(data []byte) ([]byte, MessageType, []byte, bool, bool,
}

if c.isMessageTooLarge(len(c.message) + int(bodyLen)) {
return data, 0, nil, false, false, false, ErrMessageTooLarge
return 0, 0, nil, false, false, false, ErrMessageTooLarge
}

if (bodyLen > maxControlFramePayloadSize) &&
((opcode == PingMessage) || (opcode == PongMessage) || (opcode == CloseMessage)) {
return data, 0, nil, false, false, false, ErrControlMessageTooBig
return 0, 0, nil, false, false, false, ErrControlMessageTooBig
}

if bodyLen >= 0 {
masked := (data[1] & 0x80) != 0
if masked {
headLen += 4
}
total := headLen + bodyLen
total = headLen + bodyLen
if l >= total {
body = data[headLen:total]
if masked {
maskXOR(body, data[headLen-4:headLen])
}

ok = true
data = data[total:l]
err = c.validFrame(opcode, fin, res1, res2, res3, c.expectingFragments)
}
}
}

return data, opcode, body, ok, fin, res1, err
return int(total), opcode, body, ok, fin, res1, err
}

// Read .
func (c *Conn) Parse(data []byte) error {
if len(data) == 0 {
return nil
}

c.mux.Lock()
if c.closed {
c.mux.Unlock()
Expand All @@ -341,12 +346,12 @@ func (c *Conn) Parse(data []byte) error {
return nbhttp.ErrTooLong
}

var appended = false
var allocator = c.Engine.BodyAllocator
if len(c.bytesCached) > 0 {
if len(c.bytesCached) == 0 {
c.bytesCached = allocator.Malloc(len(data))
copy(c.bytesCached, data)
} else {
c.bytesCached = allocator.Append(c.bytesCached, data...)
data = c.bytesCached
appended = true
}
c.mux.Unlock()

Expand All @@ -357,7 +362,18 @@ func (c *Conn) Parse(data []byte) error {
var protocolMessage []byte
var opcode MessageType
var ok, fin, compress bool
var totalFrameSize int

updateCache := func(consumed int) {
l := len(c.bytesCached)
if l == consumed {
c.Engine.BodyAllocator.Free(c.bytesCached)
c.bytesCached = nil
} else {
copy(c.bytesCached, data[consumed:l])
c.bytesCached = c.bytesCached[:l-consumed]
}
}
releaseBuf := func() {
if len(frame) > 0 {
allocator.Free(frame)
Expand All @@ -378,31 +394,30 @@ func (c *Conn) Parse(data []byte) error {
err = net.ErrClosed
return
}
data, opcode, body, ok, fin, compress, err = c.nextFrame(data)
totalFrameSize, opcode, body, ok, fin, compress, err = c.nextFrame()
if err != nil {
return
}
if !ok {
return
}

bl := len(body)
switch opcode {
case FragmentMessage, TextMessage, BinaryMessage:
if c.msgType == 0 {
c.msgType = opcode
c.compress = compress
}
bl := len(body)
if c.dataFrameHandler != nil {
if bl > 0 {
frame = allocator.Malloc(bl)
copy(frame, body)
}
if c.msgType == TextMessage && len(frame) > 0 && !c.Engine.CheckUtf8(frame) {
c.Conn.Close()
err = ErrInvalidUtf8
return
}
if bl > 0 && c.dataFrameHandler != nil {
frame = allocator.Malloc(bl)
copy(frame, body)
// if compressed, should check utf8 after decompressed the whole message.
// if c.msgType == TextMessage && len(frame) > 0 && !c.Engine.CheckUtf8(frame) {
// c.Conn.Close()
// err = ErrInvalidUtf8
// return
// }
}
if c.messageHandler != nil {
if bl > 0 {
Expand All @@ -416,36 +431,6 @@ func (c *Conn) Parse(data []byte) error {
if fin {
message = c.message
c.message = nil
}
}
case PingMessage, PongMessage, CloseMessage:
if len(body) > 0 {
protocolMessage = allocator.Malloc(len(body))
copy(protocolMessage, body)
}
default:
err = ErrInvalidFragmentMessage
return
}
}()

if err != nil {
releaseBuf()
if errors.Is(err, ErrMessageTooLarge) || errors.Is(err, ErrControlMessageTooBig) {
c.WriteClose(1009, err.Error())
}
return err
}

if ok {
switch opcode {
case FragmentMessage, TextMessage, BinaryMessage:
if c.dataFrameHandler != nil {
c.handleDataFrame(c.msgType, fin, frame)
frame = nil
}
if fin {
if c.messageHandler != nil {
if c.compress {
var b []byte
var rc io.ReadCloser
Expand All @@ -460,64 +445,56 @@ func (c *Conn) Parse(data []byte) error {
rc.Close()
if err != nil {
releaseBuf()
return err
return
}
}
c.handleMessage(c.msgType, message)
message = nil
c.compress = false
c.expectingFragments = false
c.msgType = 0
} else {
c.expectingFragments = true
}
c.compress = false
c.expectingFragments = false
c.msgType = 0
} else {
c.expectingFragments = true
}
case PingMessage, PongMessage, CloseMessage:
c.handleProtocolMessage(opcode, protocolMessage)
protocolMessage = nil
if bl > 0 {
protocolMessage = allocator.Malloc(len(body))
copy(protocolMessage, body)
}
default:
releaseBuf()
return ErrInvalidFragmentMessage
err = ErrInvalidFragmentMessage
return
}
} else {
goto Exit
}

if len(data) == 0 {
goto Exit
}
}
updateCache(totalFrameSize)
}()

Exit:
releaseBuf()
c.mux.Lock()
defer c.mux.Unlock()
if c.closed {
return net.ErrClosed
}
// The data bytes were not all consumed, need to recache the current bytes left:
if len(data) > 0 {
// The data bytes were appended to the tail of the previous chaced data:
if appended {
// If data bytes were consumed, move data to the head of the cached bytes,
// else the data is same as the cached bytes, nothing to do.
if len(data) < len(c.bytesCached) {
c.bytesCached = c.bytesCached[:len(data)]
copy(c.bytesCached, data)
if err != nil {
if errors.Is(err, ErrMessageTooLarge) || errors.Is(err, ErrControlMessageTooBig) {
c.WriteClose(1009, err.Error())
}
} else { // When using the origin data passed to this `Parse` func:
c.bytesCached = allocator.Malloc(len(data))
copy(c.bytesCached, data)
return err
}
} else { // The data bytes were all consumed:
// If the data bytes were cached, release the bytes and clear the cache.
if len(c.bytesCached) > 0 {
allocator.Free(c.bytesCached)
c.bytesCached = nil

if message != nil {
c.handleMessage(c.msgType, message)
message = nil
}
if frame != nil {
c.handleDataFrame(c.msgType, fin, frame)
frame = nil
}
if protocolMessage != nil {
c.handleProtocolMessage(opcode, protocolMessage)
protocolMessage = nil
}

// need more data
if !ok {
break
}
}

return err
return nil
}

// OnMessage .
Expand Down

0 comments on commit 985998b

Please sign in to comment.