Skip to content

Commit

Permalink
+测试代码
Browse files Browse the repository at this point in the history
  • Loading branch information
guonaihong committed Aug 18, 2023
1 parent 6139783 commit 5127496
Show file tree
Hide file tree
Showing 5 changed files with 191 additions and 2 deletions.
14 changes: 14 additions & 0 deletions callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func (defcallback *DefCallback) OnMessage(_ *Conn, _ Opcode, _ []byte) {
func (defcallback *DefCallback) OnClose(_ *Conn, _ error) {
}

// 只设置OnMessage
type OnMessageFunc func(*Conn, Opcode, []byte)

func (o OnMessageFunc) OnOpen(_ *Conn) {
Expand All @@ -43,3 +44,16 @@ func (o OnMessageFunc) OnMessage(c *Conn, op Opcode, data []byte) {

func (o OnMessageFunc) OnClose(_ *Conn, _ error) {
}

// 只设置OnClose
type OnCloseFunc func(*Conn, error)

func (o OnCloseFunc) OnOpen(_ *Conn) {
}

func (o OnCloseFunc) OnMessage(_ *Conn, _ Opcode, _ []byte) {
}

func (o OnCloseFunc) OnClose(c *Conn, err error) {
o(c, err)
}
15 changes: 15 additions & 0 deletions common_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,18 @@ func WithClientReadTimeout(t time.Duration) ClientOption {
o.readTimeout = t
}
}

// 17。 只配置OnClose
// 17.1 配置服务端OnClose
func WithServerOnCloseFunc(onClose func(c *Conn, err error)) ServerOption {
return func(o *ConnOption) {
o.Callback = OnCloseFunc(onClose)
}
}

// 17.2 配置客户端OnClose
func WithClientOnCloseFunc(onClose func(c *Conn, err error)) ClientOption {
return func(o *DialOption) {
o.Callback = OnCloseFunc(onClose)
}
}
42 changes: 42 additions & 0 deletions common_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1463,4 +1463,46 @@ func Test_CommonOption(t *testing.T) {
t.Error("not run server:method fail")
}
})

t.Run("17.client.WithClientOnCloseFunc", func(t *testing.T) {
run := int32(0)
data := make(chan string, 1)
upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerEnableUTF8Check(), WithServerOnCloseFunc(func(c *Conn, err error) {
c.Close()
}))
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrade.Upgrade(w, r)
if err != nil {
t.Error(err)
}
c.StartReadLoop()
}))

defer ts.Close()

url := strings.ReplaceAll(ts.URL, "http", "ws")
con, err := Dial(url, WithClientDisableBufioClearHack(),
WithClientEnableUTF8Check(), WithClientOnCloseFunc(func(c *Conn, err error) {
atomic.AddInt32(&run, 1)
data <- err.Error()
}))
if err != nil {
t.Error(err)
}
defer con.Close()
// 这里必须要报错
err = con.WriteMessage(Text, []byte("hello"))
if err != nil {
t.Error("not error")
}
con.StartReadLoop()
select {
case _ = <-data:
case <-time.After(500 * time.Millisecond):
}

if atomic.LoadInt32(&run) != 0 {
t.Error("not run server:method fail")
}
})
}
4 changes: 2 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ func (c *Conn) readLoop() error {
// bufio 模式才会使用payload
payload = *bytespool.GetBytes(1024 + enum.MaxFrameHeaderSize)
}

for {

// 从网络读取数据
Expand Down Expand Up @@ -462,8 +463,7 @@ func (c *Conn) WriteControl(op Opcode, data []byte) (err error) {
// 写分段数据, 目前主要是单元测试使用
func (c *Conn) writeFragment(op Opcode, writeBuf []byte, maxFragment int /*单个段最大size*/) (err error) {
if len(writeBuf) < maxFragment {
c.WriteMessage(op, writeBuf)
return
return c.WriteMessage(op, writeBuf)
}

if op == opcode.Text {
Expand Down
118 changes: 118 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,124 @@ func TestFragmentFrame(t *testing.T) {
t.Error("not run server:method fail")
}
})

t.Run("FragmentFrame-Small-Buffer", func(t *testing.T) {
run := int32(0)
data := make(chan string, 1)
upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) {
c.WriteMessage(op, payload)
}))
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrade.Upgrade(w, r)
if err != nil {
t.Error(err)
}
c.StartReadLoop()
}))

defer ts.Close()

url := strings.ReplaceAll(ts.URL, "http", "ws")
con, err := Dial(url, WithClientDisableBufioClearHack(), WithClientEnableUTF8Check(),
WithClientDecompressAndCompress(), WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) {
atomic.AddInt32(&run, int32(1))
data <- string(payload)
}))
if err != nil {
t.Error(err)
}
defer con.Close()

sendData := []byte("hell")
// 这里必须要报错
err = con.writeFragment(Text, sendData, 5)
if err != nil {
t.Errorf("error:%v", err)
}

con.StartReadLoop()

select {
case d := <-data:
if d != string(sendData) {
t.Errorf("write message or read message fail:got:%s, need:hello\n", d)
}
case <-time.After(1000 * time.Millisecond):
}
if atomic.LoadInt32(&run) != 1 {
t.Error("not run server:method fail")
}
})

t.Run("FragmentFrame-Client-Not-UTF8", func(t *testing.T) {
upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) {
c.WriteMessage(op, payload)
}))
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrade.Upgrade(w, r)
if err != nil {
t.Error(err)
}
c.StartReadLoop()
}))

defer ts.Close()

url := strings.ReplaceAll(ts.URL, "http", "ws")
con, err := Dial(url, WithClientDisableBufioClearHack(), WithClientEnableUTF8Check(),
WithClientDecompressAndCompress(), WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) {
}))
if err != nil {
t.Error(err)
}
defer con.Close()

// 这里必须要报错
err = con.writeFragment(Text, []byte{128, 129, 130, 131}, 1)
if err == nil {
t.Error("not error")
}
})

t.Run("FragmentFrame-Server-Not-UTF8", func(t *testing.T) {
run := int32(0)
data := make(chan string, 1)
upgrade := NewUpgrade(WithServerBufioParseMode(), WithServerEnableUTF8Check(), WithServerOnCloseFunc(func(c *Conn, err error) {
data <- err.Error()
}))
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := upgrade.Upgrade(w, r)
if err != nil {
t.Error(err)
}
c.StartReadLoop()
}))

defer ts.Close()

url := strings.ReplaceAll(ts.URL, "http", "ws")
con, err := Dial(url, WithClientDisableBufioClearHack(),
WithClientDecompressAndCompress(), WithClientOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) {
}))
if err != nil {
t.Error(err)
}
defer con.Close()
// 这里必须要报错
err = con.writeFragment(Text, []byte{128, 129, 130, 131}, 1)
if err != nil {
t.Error("error")
}
con.StartReadLoop()
select {
case _ = <-data:
case <-time.After(500 * time.Millisecond):
}

if atomic.LoadInt32(&run) != 0 {
t.Error("not run server:method fail")
}
})
}

type testPingPongCloseHandler struct {
Expand Down

0 comments on commit 5127496

Please sign in to comment.