Skip to content

Commit

Permalink
重构: bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
guonaihong committed Jun 29, 2023
1 parent 2721a49 commit ab6b5fa
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 48 deletions.
14 changes: 9 additions & 5 deletions autobahn/autobahn-client-testsuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ import (
// https://github.com/snapview/tokio-tungstenite/blob/master/examples/autobahn-client.rs

const (
host = "ws://192.168.128.44:9003"
// host = "ws://127.0.0.1:9003"
// host = "ws://192.168.128.44:9003"
host = "ws://127.0.0.1:9003"
agent = "quickws"
)

Expand All @@ -25,9 +25,9 @@ func (e *echoHandler) OnOpen(c *quickws.Conn) {
}

func (e *echoHandler) OnMessage(c *quickws.Conn, op quickws.Opcode, msg []byte) {
// fmt.Println("OnMessage:", c, msg, op)
if op == quickws.Text || op == quickws.Binary {
if err := c.WriteTimeout(op, msg, 3*time.Second); err != nil {
// os.WriteFile("./debug.dat", msg, 0o644)
if err := c.WriteTimeout(op, msg, 1*time.Minute); err != nil {
fmt.Println("write fail:", err)
}
}
Expand Down Expand Up @@ -79,7 +79,11 @@ func getCaseCount() int {

func runTest(caseNo int) {
done := make(chan struct{})
c, err := quickws.Dial(fmt.Sprintf("%s/runCase?case=%d&agent=%s", host, caseNo, agent), quickws.WithClientCallback(&echoHandler{done: done}))
c, err := quickws.Dial(fmt.Sprintf("%s/runCase?case=%d&agent=%s", host, caseNo, agent),
quickws.WithClientReplyPing(),
quickws.WithClientDecompression(),
quickws.WithClientCallback(&echoHandler{done: done}),
)
if err != nil {
fmt.Println("Dial fail:", err)
return
Expand Down
3 changes: 2 additions & 1 deletion autobahn/autobahn-server-testsuite.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ func (e *echoHandler) OnClose(c *quickws.Conn, err error) {

// echo测试服务
func echo(w http.ResponseWriter, r *http.Request) {
c, err := quickws.Upgrade(w, r, quickws.WithServerReplyPing(),
c, err := quickws.Upgrade(w, r,
quickws.WithServerReplyPing(),
quickws.WithServerDecompression(),
quickws.WithServerIgnorePong(),
quickws.WithServerCallback(&echoHandler{}),
Expand Down
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,10 @@ func (d *DialOption) Dial() (c *Conn, err error) {
if err != nil {
return nil, err
}
b2 := getBytes(len(b))
b2 := getBytes(len(b) + maxFrameHeaderSize)
copy(*b2, b)
fr = newBuffer(conn, b2)
fr.w = len(*b2)
fr.w = len(b)
}
// fmt.Println(brw.Reader.Buffered())
return newConn(conn, true, d.config, fr), nil
Expand Down
14 changes: 7 additions & 7 deletions client_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,49 +37,49 @@ func WithClientOnMessageFunc(cb OnMessageFunc) OptionClient {
}

// 配置tls.config
func WithTLSConfig(tls *tls.Config) OptionClient {
func WithClientTLSConfig(tls *tls.Config) OptionClient {
return func(o *DialOption) {
o.tlsConfig = tls
}
}

// 配置http.Header
func WithHTTPHeader(h http.Header) OptionClient {
func WithClientHTTPHeader(h http.Header) OptionClient {
return func(o *DialOption) {
o.Header = h
}
}

// 配置握手时的timeout
func WithDialTimeout(t time.Duration) OptionClient {
func WithClientDialTimeout(t time.Duration) OptionClient {
return func(o *DialOption) {
o.dialTimeout = t
}
}

// 配置自动回应ping frame, 当收到ping, 回一个pong
func WithReplyPing() OptionClient {
func WithClientReplyPing() OptionClient {
return func(o *DialOption) {
o.replyPing = true
}
}

// 配置解压缩
func WithDecompression() OptionClient {
func WithClientDecompression() OptionClient {
return func(o *DialOption) {
o.decompression = true
}
}

// 配置压缩
func WithCompression() OptionClient {
func WithClientCompression() OptionClient {
return func(o *DialOption) {
o.compression = true
}
}

// 配置压缩和解压缩
func WithDecompressAndCompress() OptionClient {
func WithClientDecompressAndCompress() OptionClient {
return func(o *DialOption) {
o.compression = true
o.decompression = true
Expand Down
28 changes: 15 additions & 13 deletions fixedreader.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ func (b *fixedReader) reset(buf *[]byte) {
b.buf = *buf
}

// 返回底层[]byte的长度
func (b *fixedReader) Len() int {
return len(b.buf)
}

func (b *fixedReader) ptr() *[]byte {
return b.p
}
Expand All @@ -62,8 +67,14 @@ func (b *fixedReader) bytes() []byte {
return b.buf
}

func (b *fixedReader) remainingLen() int {
return len(b.buf) - b.w
// 返回剩余可写的缓存区大小
func (b *fixedReader) writeCap() int {
return len(b.buf[b.w:])
}

// 返回剩余可用的缓存区大小
func (b *fixedReader) available() int64 {
return int64(len(b.buf[b.w:]) + b.r)
}

// 左移缓存区
Expand All @@ -77,23 +88,14 @@ func (b *fixedReader) leftMove() {
}

// 返回可写的缓存区
func (b *fixedReader) free() []byte {
r := b.r
copy(b.buf, b.buf[r:])
b.w -= r
b.r = 0
func (b *fixedReader) writeCapBytes() []byte {
return b.buf[b.w:]
}

func (b *fixedReader) availableBuf() *fixedReader {
func (b *fixedReader) cloneAvailable() *fixedReader {
return &fixedReader{rd: b.rd, buf: b.buf[b.w:]}
}

// 返回剩余可用的缓存区大小
func (b *fixedReader) available() int64 {
return int64(len(b.buf[b.w:]) + b.r)
}

func (b *fixedReader) Buffered() int { return b.w - b.r }

// 这和一般read接口中不一样
Expand Down
43 changes: 23 additions & 20 deletions frame.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,15 @@ type frame struct {
}

func readFrame(r *fixedReader, headArray *[maxFrameHeaderSize]byte) (f frame, err error) {
if r.remainingLen() < maxFrameHeaderSize && r.w-r.r < maxFrameHeaderSize {
// 如果剩余可写缓存区放不下一个frame header, 就把数据往前移动
// 所有的的buf分配都是paydload + frame head 的长度, 挪完之后,肯定是能放下一个frame header的
if r.Len()-r.r < maxFrameHeaderSize {
r.leftMove()
if r.Len() < maxFrameHeaderSize {
panic("readFrame r.Len() < maxFrameHeaderSize")
}
}

h, _, err := readHeader(r, headArray)
if err != nil {
return f, err
Expand All @@ -63,49 +69,46 @@ func readFrame(r *fixedReader, headArray *[maxFrameHeaderSize]byte) (f frame, er

// 已读取未处理的数据
readUnhandle := int64(r.w - r.r)
// 情况 1,需要读的长度 > 剩余可用空间(未写的+已经被读取走的)
if h.payloadLen-readUnhandle > r.available() {
// 取得旧的buf
// 1.取得旧的buf
oldBuf := r.ptr()
// 获取新的buf
// 2.获取新的buf
newBuf := getBytes(int(h.payloadLen) + maxFrameHeaderSize)
// 重置缓存区
// 3.重置缓存区
r.reset(newBuf)
// 将旧的buf放回池子里
// 4.将旧的buf放回池子里
putBytes(oldBuf)

// 情况 2。 空间是够的,需要挪一挪, 把已经读过的覆盖掉
} else if h.payloadLen-readUnhandle > int64(r.writeCap()) {
r.leftMove()
}

// 返回可写的缓存区, 把已经读取的数据去掉,这里是把frame header的数据去掉
payload := r.free()
// 返回可写的缓存区
payload := r.writeCapBytes()
// 前面的reset已经保证了,buffer的大小是够的
needRead := 0
if h.payloadLen-readUnhandle > 0 {
// 还需要读取的数据等于 h.payloadLen - (h.w - h.r)
needRead = int(h.payloadLen - readUnhandle)
}

if r.r != 0 {
panic("readFrame r != 0")
}
needRead := h.payloadLen - readUnhandle

if needRead > 0 {
// payload是一块干净可写的空间,使用needRead框下范围
payload = payload[:needRead]
// 新建一对新的r w指向尾部的内存区域
right := r.availableBuf()
right := r.cloneAvailable()
if _, err = io.ReadFull(right, payload); err != nil {
return f, err
}

// right 也有可能超读, 直接加上payload的长度,会把超读的数据给丢了
// 为什么会发生超读呢,right持的buf 会 >= payload的长度
r.w += right.w
}
r.r += int(h.payloadLen)

f.payload = r.bytes()[:h.payloadLen]
f.payload = r.bytes()[r.r : r.r+int(h.payloadLen)]
f.frameHeader = h
r.r += int(h.payloadLen)
if h.mask {
key := binary.LittleEndian.Uint32(h.maskValue[:])
// mask(f.payload, f.maskValue[:])
mask.Mask(f.payload, key)
}

Expand Down

0 comments on commit ab6b5fa

Please sign in to comment.