Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

websocket: simplify cache append and concurrent free logic #439

Merged
merged 2 commits into from
Jun 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 70 additions & 96 deletions nbhttp/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,16 @@ ErrExit:
c.Close()
}

func (c *Conn) nextFrame(data []byte) ([]byte, MessageType, []byte, bool, bool, bool, error) {
func (c *Conn) nextFrame() (int, MessageType, []byte, bool, bool, bool, error) {
var (
opcode MessageType
body []byte
ok, fin, res1, res2, res3 bool
err error
data = c.bytesCached
l = int64(len(data))
headLen = int64(2)
total int64
)
if l >= 2 {
opcode = MessageType(data[0] & 0xF)
Expand All @@ -297,38 +299,41 @@ func (c *Conn) nextFrame(data []byte) ([]byte, MessageType, []byte, bool, bool,
}

if c.isMessageTooLarge(len(c.message) + int(bodyLen)) {
return data, 0, nil, false, false, false, ErrMessageTooLarge
return 0, 0, nil, false, false, false, ErrMessageTooLarge
}

if (bodyLen > maxControlFramePayloadSize) &&
((opcode == PingMessage) || (opcode == PongMessage) || (opcode == CloseMessage)) {
return data, 0, nil, false, false, false, ErrControlMessageTooBig
return 0, 0, nil, false, false, false, ErrControlMessageTooBig
}

if bodyLen >= 0 {
masked := (data[1] & 0x80) != 0
if masked {
headLen += 4
}
total := headLen + bodyLen
total = headLen + bodyLen
if l >= total {
body = data[headLen:total]
if masked {
maskXOR(body, data[headLen-4:headLen])
}

ok = true
data = data[total:l]
err = c.validFrame(opcode, fin, res1, res2, res3, c.expectingFragments)
}
}
}

return data, opcode, body, ok, fin, res1, err
return int(total), opcode, body, ok, fin, res1, err
}

// Read .
func (c *Conn) Parse(data []byte) error {
if len(data) == 0 {
return nil
}

c.mux.Lock()
if c.closed {
c.mux.Unlock()
Expand All @@ -341,12 +346,12 @@ func (c *Conn) Parse(data []byte) error {
return nbhttp.ErrTooLong
}

var appended = false
var allocator = c.Engine.BodyAllocator
if len(c.bytesCached) > 0 {
if len(c.bytesCached) == 0 {
c.bytesCached = allocator.Malloc(len(data))
copy(c.bytesCached, data)
} else {
c.bytesCached = allocator.Append(c.bytesCached, data...)
data = c.bytesCached
appended = true
}
c.mux.Unlock()

Expand All @@ -357,6 +362,7 @@ func (c *Conn) Parse(data []byte) error {
var protocolMessage []byte
var opcode MessageType
var ok, fin, compress bool
var totalFrameSize int

releaseBuf := func() {
if len(frame) > 0 {
Expand All @@ -378,31 +384,30 @@ func (c *Conn) Parse(data []byte) error {
err = net.ErrClosed
return
}
data, opcode, body, ok, fin, compress, err = c.nextFrame(data)
totalFrameSize, opcode, body, ok, fin, compress, err = c.nextFrame()
if err != nil {
return
}
if !ok {
return
}

bl := len(body)
switch opcode {
case FragmentMessage, TextMessage, BinaryMessage:
if c.msgType == 0 {
c.msgType = opcode
c.compress = compress
}
bl := len(body)
if c.dataFrameHandler != nil {
if bl > 0 {
frame = allocator.Malloc(bl)
copy(frame, body)
}
if c.msgType == TextMessage && len(frame) > 0 && !c.Engine.CheckUtf8(frame) {
c.Conn.Close()
err = ErrInvalidUtf8
return
}
if bl > 0 && c.dataFrameHandler != nil {
frame = allocator.Malloc(bl)
copy(frame, body)
// if compressed, should check utf8 after decompressed the whole message.
// if c.msgType == TextMessage && len(frame) > 0 && !c.Engine.CheckUtf8(frame) {
// c.Conn.Close()
// err = ErrInvalidUtf8
// return
// }
}
if c.messageHandler != nil {
if bl > 0 {
Expand All @@ -416,36 +421,6 @@ func (c *Conn) Parse(data []byte) error {
if fin {
message = c.message
c.message = nil
}
}
case PingMessage, PongMessage, CloseMessage:
if len(body) > 0 {
protocolMessage = allocator.Malloc(len(body))
copy(protocolMessage, body)
}
default:
err = ErrInvalidFragmentMessage
return
}
}()

if err != nil {
releaseBuf()
if errors.Is(err, ErrMessageTooLarge) || errors.Is(err, ErrControlMessageTooBig) {
c.WriteClose(1009, err.Error())
}
return err
}

if ok {
switch opcode {
case FragmentMessage, TextMessage, BinaryMessage:
if c.dataFrameHandler != nil {
c.handleDataFrame(c.msgType, fin, frame)
frame = nil
}
if fin {
if c.messageHandler != nil {
if c.compress {
var b []byte
var rc io.ReadCloser
Expand All @@ -460,64 +435,63 @@ func (c *Conn) Parse(data []byte) error {
rc.Close()
if err != nil {
releaseBuf()
return err
return
}
}
c.handleMessage(c.msgType, message)
message = nil
c.compress = false
c.expectingFragments = false
c.msgType = 0
} else {
c.expectingFragments = true
}
c.compress = false
c.expectingFragments = false
c.msgType = 0
} else {
c.expectingFragments = true
}
case PingMessage, PongMessage, CloseMessage:
c.handleProtocolMessage(opcode, protocolMessage)
protocolMessage = nil
if bl > 0 {
protocolMessage = allocator.Malloc(len(body))
copy(protocolMessage, body)
}
default:
releaseBuf()
return ErrInvalidFragmentMessage
err = ErrInvalidFragmentMessage
return
}
} else {
goto Exit
}

if len(data) == 0 {
goto Exit
}
}
l := len(c.bytesCached)
if l == totalFrameSize {
c.Engine.BodyAllocator.Free(c.bytesCached)
c.bytesCached = nil
} else {
copy(c.bytesCached, c.bytesCached[totalFrameSize:l])
c.bytesCached = c.bytesCached[:l-totalFrameSize]
}
}()

Exit:
releaseBuf()
c.mux.Lock()
defer c.mux.Unlock()
if c.closed {
return net.ErrClosed
}
// The data bytes were not all consumed, need to recache the current bytes left:
if len(data) > 0 {
// The data bytes were appended to the tail of the previous chaced data:
if appended {
// If data bytes were consumed, move data to the head of the cached bytes,
// else the data is same as the cached bytes, nothing to do.
if len(data) < len(c.bytesCached) {
c.bytesCached = c.bytesCached[:len(data)]
copy(c.bytesCached, data)
if err != nil {
if errors.Is(err, ErrMessageTooLarge) || errors.Is(err, ErrControlMessageTooBig) {
c.WriteClose(1009, err.Error())
}
} else { // When using the origin data passed to this `Parse` func:
c.bytesCached = allocator.Malloc(len(data))
copy(c.bytesCached, data)
return err
}
} else { // The data bytes were all consumed:
// If the data bytes were cached, release the bytes and clear the cache.
if len(c.bytesCached) > 0 {
allocator.Free(c.bytesCached)
c.bytesCached = nil

if message != nil {
c.handleMessage(c.msgType, message)
message = nil
}
if frame != nil {
c.handleDataFrame(c.msgType, fin, frame)
frame = nil
}
if protocolMessage != nil {
c.handleProtocolMessage(opcode, protocolMessage)
protocolMessage = nil
}

// need more data
if !ok {
break
}
}

return err
return nil
}

// OnMessage .
Expand Down
Loading