Skip to content

Commit

Permalink
+测试代码
Browse files Browse the repository at this point in the history
  • Loading branch information
guonaihong committed Aug 16, 2023
1 parent e133eeb commit 1d8e38f
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 18 deletions.
85 changes: 69 additions & 16 deletions common_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ import (
"time"
)

var badUTF8 = []byte{128, 129, 130, 131}

// 测试客户端和服务端都有的配置项
func Test_CommonOption(t *testing.T) {
t.Run("2.server.local: WithServerTCPDelay", func(t *testing.T) {
Expand Down Expand Up @@ -153,10 +155,12 @@ func Test_CommonOption(t *testing.T) {
}
defer con.Close()

err = con.WriteMessage(Text, []byte{1, 2, 3, 4})
if err != nil {
t.Error(err)
err = con.WriteMessage(Text, badUTF8)
if err == nil {
t.Error("写入非法utf8数据,没有报错")
}

con.WriteMessage(Binary, badUTF8)
select {
case <-done:
case <-time.After(1000 * time.Millisecond):
Expand Down Expand Up @@ -192,10 +196,12 @@ func Test_CommonOption(t *testing.T) {
}
defer con.Close()

err = con.WriteMessage(Text, []byte{1, 2, 3, 4})
if err != nil {
t.Error(err)
err = con.WriteMessage(Text, badUTF8)
if err == nil {
t.Error("写入非法utf8数据,没有报错")
}
_ = con.WriteMessage(Binary, badUTF8)

select {
case <-done:
case <-time.After(1000 * time.Millisecond):
Expand All @@ -206,7 +212,44 @@ func Test_CommonOption(t *testing.T) {
})

t.Run("3.client: WithClientDisableUTF8Check", func(t *testing.T) {
// 3.server.local: WithServerDisableUTF8Check 已经测试过客户端,所以这里留空
// 客户端不检查utf8, 服务端检查utf8
run := int32(0)
done := make(chan bool, 1)
upgrade := NewUpgrade(WithServerEnableUTF8Check(), WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) {
c.WriteMessage(op, payload)
atomic.AddInt32(&run, int32(1))
done <- true
}))

ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrade.Upgrade(w, r)
if err != nil {
t.Error(err)
}
c.StartReadLoop()
}))

defer ts.Close()

url := strings.ReplaceAll(ts.URL, "http", "ws")
con, err := Dial(url, WithClientCallback(&testDefaultCallback{}))
if err != nil {
t.Error(err)
}
defer con.Close()

err = con.WriteMessage(Text, badUTF8)
if err != nil {
t.Error("关闭utf8检查, 写入非法utf8数据,不报错")
}

select {
case <-done:
case <-time.After(100 * time.Millisecond):
}
if atomic.LoadInt32(&run) != 0 {
t.Error("not run server:method fail")
}
})

t.Run("4.server.local: WithServerOnMessageFunc", func(t *testing.T) {
Expand Down Expand Up @@ -1039,9 +1082,11 @@ func Test_CommonOption(t *testing.T) {
run := int32(0)
data := make(chan string, 1)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := Upgrade(w, r, WithServerWindowsParseMode(), WithServerBufioMultipleTimesPayloadSize(-1) /*这里写-1只是为了代码覆盖度测试*/, WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) {
c.WriteMessage(op, payload)
}))
c, err := Upgrade(w, r, WithServerWindowsParseMode(),
WithServerBufioMultipleTimesPayloadSize(6), /*这里写-1只是为了代码覆盖度测试*/
WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) {
c.WriteMessage(op, payload)
}))
if err != nil {
t.Error(err)
}
Expand Down Expand Up @@ -1245,11 +1290,14 @@ func Test_CommonOption(t *testing.T) {
run := int32(0)
data := make(chan string, 1)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := Upgrade(w, r, WithServerBufioParseMode(), WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) {
c.WriteMessage(op, payload)
atomic.AddInt32(&run, int32(1))
data <- string(payload)
}))
c, err := Upgrade(w, r,
WithServerDecompressAndCompress(),
WithServerBufioParseMode(),
WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) {
c.WriteMessage(op, payload)
atomic.AddInt32(&run, int32(1))
data <- string(payload)
}))
if err != nil {
t.Error(err)
}
Expand All @@ -1259,7 +1307,12 @@ func Test_CommonOption(t *testing.T) {
defer ts.Close()

url := strings.ReplaceAll(ts.URL, "http", "ws")
con, err := Dial(url, WithClientMaxDelayWriteDuration(10*time.Millisecond), WithClientMaxDelayWriteNum(3), WithClientWindowsParseMode(), WithClientDelayWriteInitBufferSize(4096),
con, err := Dial(url, WithClientDecompressAndCompress(),
WithClientDecompression(),
WithClientMaxDelayWriteDuration(10*time.Millisecond),
WithClientMaxDelayWriteNum(3),
WithClientWindowsParseMode(),
WithClientDelayWriteInitBufferSize(4096),
WithClientOnMessageFunc(func(c *Conn, op Opcode, payload []byte) {
c.WriteMessageDelay(op, []byte("hello"))
c.WriteMessageDelay(op, []byte("hello"))
Expand Down
4 changes: 2 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ func (c *Conn) readLoop() error {
// TODO sync.Pool管理
(*bufio2.Reader2)(unsafe.Pointer(c.br)).ResetBuf(make([]byte, newSize))
}
// bufio 模式才会使用payload
payload = *bytespool.GetBytes(1024 + enum.MaxFrameHeaderSize)
}
for {
Expand Down Expand Up @@ -574,7 +575,7 @@ func (c *Conn) WriteMessageDelay(op Opcode, writeBuf []byte) (err error) {
// 初始化缓存
if c.delayBuf == nil && c.delayWriteInitBufferSize > 0 {

// TODO sync.Pool管理下, 如果size是1k 2k 3k
// TODO: sync.Pool管理下, 如果size是1k 2k 3k
delayBuf := make([]byte, 0, c.delayWriteInitBufferSize)
c.delayBuf = bytes.NewBuffer(delayBuf)
}
Expand Down Expand Up @@ -604,6 +605,5 @@ func (c *Conn) WriteMessageDelay(op Opcode, writeBuf []byte) (err error) {
}
c.delayNum++ // 对记数计+1
c.delayMu.Unlock()
// }()
return nil
}

0 comments on commit 1d8e38f

Please sign in to comment.