From 874a8edaa3939ed5d9aae10c1ea7036e8b538624 Mon Sep 17 00:00:00 2001 From: Ehsan Noureddin Moosa Date: Thu, 9 May 2024 17:26:42 +0300 Subject: [PATCH] refactor writeFunc --- kit/bridge_north.go | 5 ++-- kit/bridge_south.go | 14 +++++----- kit/conn.go | 1 + kit/ctx.go | 1 - kit/ctx_testkit.go | 27 ++++++------------ kit/edge_test.go | 25 ++++++++--------- kit/envelope.go | 2 +- std/gateways/fasthttp/bundle.go | 34 +++++----------------- std/gateways/fasthttp/conn_http.go | 24 ++++++++++++++++ std/gateways/fasthttp/conn_ws.go | 32 ++++++++++++++++++--- std/gateways/fastws/conn.go | 42 ++++++++++++++++++++++++---- std/gateways/fastws/gateway.go | 4 +-- std/gateways/silverhttp/bundle.go | 33 +--------------------- std/gateways/silverhttp/conn_http.go | 26 +++++++++++++++++ 14 files changed, 155 insertions(+), 115 deletions(-) diff --git a/kit/bridge_north.go b/kit/bridge_north.go index 36f04cbe..6cb6ba97 100644 --- a/kit/bridge_north.go +++ b/kit/bridge_north.go @@ -45,7 +45,7 @@ type Gateway interface { type GatewayDelegate interface { ConnDelegate // OnMessage must be called whenever a new message arrives. - OnMessage(c Conn, wf WriteFunc, msg []byte) + OnMessage(c Conn, msg []byte) } type ConnDelegate interface { @@ -84,10 +84,9 @@ func (n *northBridge) OnClose(connID uint64) { n.cd.OnClose(connID) } -func (n *northBridge) OnMessage(conn Conn, wf WriteFunc, msg []byte) { +func (n *northBridge) OnMessage(conn Conn, msg []byte) { n.wg.Add(1) ctx := n.acquireCtx(conn) - ctx.wf = wf ctx.sb = n.sb ctx.rawData = msg diff --git a/kit/bridge_south.go b/kit/bridge_south.go index d87af794..f3facac7 100644 --- a/kit/bridge_south.go +++ b/kit/bridge_south.go @@ -75,9 +75,9 @@ func (sb *southBridge) OnMessage(data []byte) error { sessionID: carrier.SessionID, serverID: sb.id, kv: map[string]string{}, + wf: sb.writeFunc, } ctx := sb.acquireCtx(conn) - ctx.wf = sb.writeFunc ctx.sb = sb switch carrier.Kind { @@ -262,12 +262,7 @@ func (sb *southBridge) genForwarderHandler(sel EdgeSelectorFunc) HandlerFunc { } } -func (sb *southBridge) writeFunc(conn Conn, e *Envelope) error { - c, ok := conn.(*clusterConn) - if !ok { - return ErrWritingToClusterConnection - } - +func (sb *southBridge) writeFunc(c *clusterConn, e *Envelope) error { ec := newEnvelopeCarrier( outgoingCarrier, c.sessionID, @@ -295,6 +290,7 @@ type clusterConn struct { kvMtx sync.Mutex kv map[string]string + wf func(c *clusterConn, e *Envelope) error } var _ Conn = (*clusterConn)(nil) @@ -311,6 +307,10 @@ func (c *clusterConn) Write(_ []byte) (int, error) { return 0, ErrWritingToClusterConnection } +func (c *clusterConn) WriteEnvelope(e *Envelope) error { + return c.wf(c, e) +} + func (c *clusterConn) Stream() bool { return c.stream } diff --git a/kit/conn.go b/kit/conn.go index 8ecc3474..c6c68b2d 100644 --- a/kit/conn.go +++ b/kit/conn.go @@ -5,6 +5,7 @@ type Conn interface { ConnID() uint64 ClientIP() string Write(data []byte) (int, error) + WriteEnvelope(e *Envelope) error Stream() bool Walk(fn func(key string, val string) bool) Get(key string) string diff --git a/kit/ctx.go b/kit/ctx.go index 11c01934..575cf914 100644 --- a/kit/ctx.go +++ b/kit/ctx.go @@ -43,7 +43,6 @@ type Context struct { hdr map[string]string conn Conn in *Envelope - wf WriteFunc modifiers []ModifierFunc err error statusCode int diff --git a/kit/ctx_testkit.go b/kit/ctx_testkit.go index 9a24842a..0b73c658 100644 --- a/kit/ctx_testkit.go +++ b/kit/ctx_testkit.go @@ -60,15 +60,6 @@ func (testCtx *TestContext) Run(stream bool) error { ctx.in. SetMsg(testCtx.inMsg). SetHdrMap(testCtx.inHdr) - ctx.wf = func(conn Conn, e *Envelope) error { - e.dontReuse() - tc := conn.(*testConn) //nolint:forcetypeassert - tc.Lock() - tc.out = append(tc.out, e) - tc.Unlock() - - return nil - } ctx.handlers = append(ctx.handlers, testCtx.handlers...) ctx.Next() @@ -85,15 +76,6 @@ func (testCtx *TestContext) RunREST() error { ctx.in. SetMsg(testCtx.inMsg). SetHdrMap(testCtx.inHdr) - ctx.wf = func(conn Conn, e *Envelope) error { - e.dontReuse() - tc := conn.(*testRESTConn) //nolint:forcetypeassert - tc.Lock() - tc.out = append(tc.out, e) - tc.Unlock() - - return nil - } ctx.handlers = append(ctx.handlers, testCtx.handlers...) ctx.Next() @@ -130,6 +112,15 @@ func (t *testConn) Write(_ []byte) (int, error) { return 0, nil } +func (t *testConn) WriteEnvelope(e *Envelope) error { + e.dontReuse() + t.Lock() + t.out = append(t.out, e) + t.Unlock() + + return nil +} + func (t *testConn) Stream() bool { return t.stream } diff --git a/kit/edge_test.go b/kit/edge_test.go index dba9ef71..332d1879 100644 --- a/kit/edge_test.go +++ b/kit/edge_test.go @@ -39,20 +39,7 @@ type testGateway struct { var _ kit.Gateway = (*testGateway)(nil) func (t *testGateway) Send(c *testConn, msg []byte) { - t.d.OnMessage( - c, - func(conn kit.Conn, e *kit.Envelope) error { - b, err := kit.MarshalMessage(e.GetMsg()) - if err != nil { - return err - } - - _, err = conn.Write(b) - - return err - }, - msg, - ) + t.d.OnMessage(c, msg) } func (t *testGateway) Start(_ context.Context, _ kit.GatewayStartConfig) error { @@ -241,6 +228,16 @@ func (t testConn) Write(data []byte) (int, error) { return t.buf.Write(data) } +func (t *testConn) WriteEnvelope(e *kit.Envelope) error { + b, err := kit.MarshalMessage(e.GetMsg()) + if err != nil { + return err + } + _, err = t.buf.Write(b) + + return err +} + func (t testConn) Read() ([]byte, error) { return io.ReadAll(t.buf) } diff --git a/kit/envelope.go b/kit/envelope.go index ff226308..ddcaf4eb 100644 --- a/kit/envelope.go +++ b/kit/envelope.go @@ -180,7 +180,7 @@ func (e *Envelope) Send() { } // Use WriteFunc to write the Envelope into the connection - e.ctx.Error(e.ctx.wf(e.conn, e)) + e.ctx.Error(e.conn.WriteEnvelope(e)) // Release the envelope e.release() diff --git a/std/gateways/fasthttp/bundle.go b/std/gateways/fasthttp/bundle.go index de5383d0..e2244767 100644 --- a/std/gateways/fasthttp/bundle.go +++ b/std/gateways/fasthttp/bundle.go @@ -204,7 +204,7 @@ func (b *bundle) genHTTPHandler(rd routeData) fasthttp.RequestHandler { c.ctx = ctx c.rd = &rd b.d.OnOpen(c) - b.d.OnMessage(c, b.writeFunc, ctx.PostBody()) + b.d.OnMessage(c, ctx.PostBody()) b.d.OnClose(c.ConnID()) b.connPool.Put(c) @@ -227,10 +227,11 @@ func (b *bundle) wsHandler(ctx *fasthttp.RequestCtx) { ctx, func(conn *websocket.Conn) { wsc := &wsConn{ - kv: map[string]string{}, - id: atomic.AddUint64(&b.wsNextID, 1), - clientIP: realip.FromRequest(ctx), - c: conn, + kv: map[string]string{}, + id: atomic.AddUint64(&b.wsNextID, 1), + clientIP: realip.FromRequest(ctx), + c: conn, + rpcOutFactory: b.rpcOutFactory, } b.d.OnOpen(wsc) for { @@ -249,7 +250,7 @@ func (b *bundle) wsHandler(ctx *fasthttp.RequestCtx) { } func (b *bundle) wsHandlerExec(buf *buf.Bytes, wsc *wsConn) { - b.d.OnMessage(wsc, b.writeFunc, *buf.Bytes()) + b.d.OnMessage(wsc, *buf.Bytes()) buf.Release() } @@ -335,28 +336,7 @@ func (b *bundle) writeFunc(conn kit.Conn, e *kit.Envelope) error { outC.Release() return err - case *httpConn: - var ( - data []byte - err error - ) - - data, err = kit.MarshalMessage(e.GetMsg()) - if err != nil { - return err - } - e.WalkHdr( - func(key string, val string) bool { - c.ctx.Response.Header.Set(key, val) - - return true - }, - ) - - c.ctx.SetBody(data) - - return nil default: panic("BUG!! incorrect connection") } diff --git a/std/gateways/fasthttp/conn_http.go b/std/gateways/fasthttp/conn_http.go index a417c6b1..a51d1389 100644 --- a/std/gateways/fasthttp/conn_http.go +++ b/std/gateways/fasthttp/conn_http.go @@ -70,6 +70,30 @@ func (c *httpConn) Write(data []byte) (int, error) { return len(data), nil } +func (c *httpConn) WriteEnvelope(e *kit.Envelope) error { + var ( + data []byte + err error + ) + + data, err = kit.MarshalMessage(e.GetMsg()) + if err != nil { + return err + } + + e.WalkHdr( + func(key string, val string) bool { + c.ctx.Response.Header.Set(key, val) + + return true + }, + ) + + c.ctx.SetBody(data) + + return nil +} + func (c *httpConn) Stream() bool { return false } diff --git a/std/gateways/fasthttp/conn_ws.go b/std/gateways/fasthttp/conn_ws.go index be12c9a2..bb5ccf46 100644 --- a/std/gateways/fasthttp/conn_ws.go +++ b/std/gateways/fasthttp/conn_ws.go @@ -8,10 +8,11 @@ import ( type wsConn struct { utils.SpinLock - kv map[string]string - id uint64 - clientIP string - c *websocket.Conn + kv map[string]string + id uint64 + clientIP string + c *websocket.Conn + rpcOutFactory kit.OutgoingRPCFactory } var _ kit.Conn = (*wsConn)(nil) @@ -47,6 +48,29 @@ func (w *wsConn) Write(data []byte) (int, error) { return len(data), nil } +func (w *wsConn) WriteEnvelope(e *kit.Envelope) error { + outC := w.rpcOutFactory() + outC.InjectMessage(e.GetMsg()) + outC.SetID(e.GetID()) + e.WalkHdr( + func(key string, val string) bool { + outC.SetHdr(key, val) + + return true + }, + ) + + data, err := outC.Marshal() + if err != nil { + return err + } + + _, err = w.Write(data) + outC.Release() + + return err +} + func (w *wsConn) Stream() bool { return true } diff --git a/std/gateways/fastws/conn.go b/std/gateways/fastws/conn.go index 7cb3222e..51bbf80f 100644 --- a/std/gateways/fastws/conn.go +++ b/std/gateways/fastws/conn.go @@ -1,6 +1,8 @@ package fastws import ( + "github.com/clubpay/ronykit/kit" + "github.com/clubpay/ronykit/kit/errors" "github.com/clubpay/ronykit/kit/utils" "github.com/gobwas/ws" "github.com/gobwas/ws/wsutil" @@ -16,15 +18,22 @@ type wsConn struct { r *wsutil.Reader w *wsutil.Writer handshakeDone bool + rpcOutFactory kit.OutgoingRPCFactory } -func newWebsocketConn(id uint64, c gnet.Conn) *wsConn { +var _ kit.Conn = (*wsConn)(nil) + +func newWebsocketConn( + id uint64, c gnet.Conn, + rpcOutFactory kit.OutgoingRPCFactory, +) *wsConn { wsc := &wsConn{ - w: wsutil.NewWriter(c, ws.StateServerSide, ws.OpText), - r: wsutil.NewReader(c, ws.StateServerSide), - id: id, - kv: map[string]string{}, - c: c, + w: wsutil.NewWriter(c, ws.StateServerSide, ws.OpText), + r: wsutil.NewReader(c, ws.StateServerSide), + id: id, + kv: map[string]string{}, + c: c, + rpcOutFactory: rpcOutFactory, } return wsc @@ -61,6 +70,27 @@ func (c *wsConn) Write(data []byte) (int, error) { return n, err } +func (c *wsConn) WriteEnvelope(e *kit.Envelope) error { + outC := c.rpcOutFactory() + outC.InjectMessage(e.GetMsg()) + outC.SetID(e.GetID()) + e.WalkHdr(func(key string, val string) bool { + outC.SetHdr(key, val) + + return true + }) + + data, err := outC.Marshal() + if err != nil { + return errors.Wrap(kit.ErrEncodeOutgoingMessageFailed, err) + } + + _, err = c.Write(data) + outC.Release() + + return err +} + func (c *wsConn) Stream() bool { return true } diff --git a/std/gateways/fastws/gateway.go b/std/gateways/fastws/gateway.go index 1213ab87..4b912e32 100644 --- a/std/gateways/fastws/gateway.go +++ b/std/gateways/fastws/gateway.go @@ -57,7 +57,7 @@ func (gw *gateway) writeFunc(conn kit.Conn, e *kit.Envelope) error { } func (gw *gateway) reactFunc(wsc kit.Conn, payload *buf.Bytes, n int) { - gw.b.d.OnMessage(wsc, gw.writeFunc, (*payload.Bytes())[:n]) + gw.b.d.OnMessage(wsc, (*payload.Bytes())[:n]) payload.Release() } @@ -80,7 +80,7 @@ func (gw *gateway) OnBoot(_ gnet.Engine) (action gnet.Action) { func (gw *gateway) OnShutdown(_ gnet.Engine) {} func (gw *gateway) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) { - wsc := newWebsocketConn(atomic.AddUint64(&gw.nextID, 1), c) + wsc := newWebsocketConn(atomic.AddUint64(&gw.nextID, 1), c, gw.b.rpcOutFactory) c.SetContext(wsc.id) gw.Lock() diff --git a/std/gateways/silverhttp/bundle.go b/std/gateways/silverhttp/bundle.go index 3b033c6f..6e25bb56 100644 --- a/std/gateways/silverhttp/bundle.go +++ b/std/gateways/silverhttp/bundle.go @@ -171,7 +171,7 @@ func (b *bundle) httpHandler(ctx *silverlining.Context) { c.ctx = ctx b.d.OnOpen(c) - b.d.OnMessage(c, b.httpWriteFunc, httpBody) + b.d.OnMessage(c, httpBody) b.d.OnClose(c.ConnID()) b.connPool.Put(c) @@ -221,34 +221,3 @@ func (b *bundle) httpDispatch(ctx *kit.Context, in []byte) (kit.ExecuteArg, erro Route: fmt.Sprintf("%s %s", routeData.Method, routeData.Path), }, nil } - -func (b *bundle) httpWriteFunc(c kit.Conn, e *kit.Envelope) error { - rc, ok := c.(*httpConn) - if !ok { - panic("BUG!! incorrect connection") - } - - var ( - data []byte - err error - ) - - data, err = kit.MarshalMessage(e.GetMsg()) - if err != nil { - return err - } - - resHdr := rc.ctx.ResponseHeaders() - e.WalkHdr( - func(key string, val string) bool { - resHdr.Set(key, val) - - return true - }, - ) - - rc.ctx.SetContentLength(len(data)) - _, err = rc.ctx.Write(data) - - return err -} diff --git a/std/gateways/silverhttp/conn_http.go b/std/gateways/silverhttp/conn_http.go index e494bb87..0417b447 100644 --- a/std/gateways/silverhttp/conn_http.go +++ b/std/gateways/silverhttp/conn_http.go @@ -70,6 +70,32 @@ func (c *httpConn) Write(data []byte) (int, error) { return len(data), err } +func (c *httpConn) WriteEnvelope(e *kit.Envelope) error { + var ( + data []byte + err error + ) + + data, err = kit.MarshalMessage(e.GetMsg()) + if err != nil { + return err + } + + resHdr := c.ctx.ResponseHeaders() + e.WalkHdr( + func(key string, val string) bool { + resHdr.Set(key, val) + + return true + }, + ) + + c.ctx.SetContentLength(len(data)) + _, err = c.ctx.Write(data) + + return err +} + func (c *httpConn) Stream() bool { return false }