diff --git a/callback.go b/callback.go index 240a5af..b1a091e 100644 --- a/callback.go +++ b/callback.go @@ -21,6 +21,10 @@ type ( } ) +type ( + OnOpenFunc func(*Conn) +) + type DefCallback struct{} func (defcallback *DefCallback) OnOpen(_ *Conn) { @@ -32,7 +36,7 @@ func (defcallback *DefCallback) OnMessage(_ *Conn, _ Opcode, _ []byte) { func (defcallback *DefCallback) OnClose(_ *Conn, _ error) { } -// 只设置OnMessage +// 只设置OnMessage, 和OnClose互斥 type OnMessageFunc func(*Conn, Opcode, []byte) func (o OnMessageFunc) OnOpen(_ *Conn) { @@ -45,7 +49,7 @@ func (o OnMessageFunc) OnMessage(c *Conn, op Opcode, data []byte) { func (o OnMessageFunc) OnClose(_ *Conn, _ error) { } -// 只设置OnClose +// 只设置OnClose, 和OnMessage互斥 type OnCloseFunc func(*Conn, error) func (o OnCloseFunc) OnOpen(_ *Conn) { @@ -57,3 +61,27 @@ func (o OnCloseFunc) OnMessage(_ *Conn, _ Opcode, _ []byte) { func (o OnCloseFunc) OnClose(c *Conn, err error) { o(c, err) } + +type funcToCallback struct { + onOpen func(*Conn) + onMessage func(*Conn, Opcode, []byte) + onClose func(*Conn, error) +} + +func (f *funcToCallback) OnOpen(c *Conn) { + if f.onOpen != nil { + f.onOpen(c) + } +} + +func (f *funcToCallback) OnMessage(c *Conn, op Opcode, data []byte) { + if f.onMessage != nil { + f.onMessage(c, op, data) + } +} + +func (f *funcToCallback) OnClose(c *Conn, err error) { + if f.onClose != nil { + f.onClose(c, err) + } +} diff --git a/common_options.go b/common_options.go index 5c95353..27eb848 100644 --- a/common_options.go +++ b/common_options.go @@ -18,6 +18,28 @@ import ( "unicode/utf8" ) +// 0. CallbackFunc +func WithClientCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) ClientOption { + return func(o *DialOption) { + o.Callback = &funcToCallback{ + onOpen: open, + onMessage: m, + onClose: c, + } + } +} + +// 配置服务端回调函数 +func WithServerCallbackFunc(open OnOpenFunc, m OnMessageFunc, c OnCloseFunc) ServerOption { + return func(o *ConnOption) { + o.Callback = &funcToCallback{ + onOpen: open, + onMessage: m, + onClose: c, + } + } +} + // 1. callback // 配置客户端callback func WithClientCallback(cb Callback) ClientOption { diff --git a/common_options_test.go b/common_options_test.go index c2eaad8..b857d70 100644 --- a/common_options_test.go +++ b/common_options_test.go @@ -45,6 +45,58 @@ func (defcallback *testServerOptionReadTimeout) OnClose(c *Conn, err error) { // 测试客户端和服务端都有的配置项 func Test_CommonOption(t *testing.T) { + t.Run("0.server.local: WithClientCallbackFunc", func(t *testing.T) { + run := int32(0) + done := make(chan bool, 1) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := Upgrade(w, r, WithServerTCPDelay(), WithServerOnMessageFunc(func(c *Conn, mt Opcode, payload []byte) { + c.WriteMessage(mt, payload) + atomic.AddInt32(&run, int32(1)) + })) + if err != nil { + t.Error(err) + } + c.StartReadLoop() + })) + + defer ts.Close() + + messageDone := make(chan bool, 1) + url := strings.ReplaceAll(ts.URL, "http", "ws") + clientRun := int32(0) + con, err := Dial(url, WithClientCallbackFunc(func(c *Conn) { + atomic.AddInt32(&clientRun, 10) + }, func(c *Conn, mt Opcode, payload []byte) { + atomic.AddInt32(&clientRun, 100) + messageDone <- true + }, func(c *Conn, err error) { + atomic.AddInt32(&clientRun, 1000) + done <- true + })) + if err != nil { + t.Error(err) + } + defer con.Close() + + con.WriteMessage(Binary, []byte("hello")) + con.StartReadLoop() + for i := 0; i < 2; i++ { + select { + case <-messageDone: + con.Close() + case <-done: + case <-time.After(100 * time.Millisecond): + } + } + if atomic.LoadInt32(&run) != 1 { + t.Error("not run server:method fail") + } + + if atomic.LoadInt32(&clientRun) != 1110 { + t.Errorf("not run client:method fail:%d, need:1110\n", atomic.LoadInt32(&clientRun)) + } + }) + t.Run("2.server.local: WithServerTCPDelay", func(t *testing.T) { run := int32(0) done := make(chan bool, 1) @@ -972,9 +1024,10 @@ func Test_CommonOption(t *testing.T) { run := int32(0) data := make(chan string, 1) ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := Upgrade(w, r, WithServerDecompressAndCompress(), WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) { - c.WriteMessage(op, payload) - })) + c, err := Upgrade(w, r, WithServerDecompressAndCompress(), + WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) { + c.WriteMessage(op, payload) + })) if err != nil { t.Error(err) } @@ -1360,38 +1413,71 @@ func Test_CommonOption(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := Upgrade(w, r, WithServerDecompressAndCompress(), - WithServerBufioParseMode(), - WithServerOnMessageFunc(func(c *Conn, op Opcode, payload []byte) { + // WithServerBufioParseMode(), + WithServerCallbackFunc(nil, func(c *Conn, op Opcode, payload []byte) { + if op != Binary { + t.Error("opcode error") + } c.WriteMessage(op, payload) - atomic.AddInt32(&run, int32(1)) - data <- string(payload) - })) + }, func(c *Conn, err error) { + // t.Errorf("%T\n", err) + }, + )) if err != nil { t.Error(err) } + + if !c.compression { + t.Error("compression fail") + } + + if !c.decompression { + t.Error("compression fail") + } c.StartReadLoop() })) defer ts.Close() url := strings.ReplaceAll(ts.URL, "http", "ws") - con, err := Dial(url, WithClientDecompressAndCompress(), - WithClientDecompression(), + con, err := Dial(url, + WithClientDecompressAndCompress(), WithClientMaxDelayWriteDuration(10*time.Millisecond), WithClientMaxDelayWriteNum(3), WithClientWindowsParseMode(), WithClientDelayWriteInitBufferSize(4096), WithClientOnMessageFunc(func(c *Conn, op Opcode, payload []byte) { - c.WriteMessageDelay(op, []byte("hello")) - c.WriteMessageDelay(op, []byte("hello")) - c.WriteMessageDelay(op, []byte("hello")) + if op != Binary { + t.Error("opcode error") + } + err := c.WriteMessageDelay(op, []byte("hello")) + if err != nil { + t.Error(err) + } + err = c.WriteMessageDelay(op, []byte("hello")) + if err != nil { + t.Error(err) + } + err = c.WriteMessageDelay(op, []byte("hello")) + if err != nil { + t.Error(err) + } + data <- "hello" + atomic.AddInt32(&run, int32(1)) })) if err != nil { t.Error(err) } defer con.Close() - con.WriteMessage(Binary, []byte("hello")) + if !con.compression { + t.Error("not compression:method fail") + } + err = con.WriteMessage(Binary, []byte("hello")) + if err != nil { + t.Error(err) + } + con.StartReadLoop() select { case d := <-data: