Skip to content

Commit

Permalink
refactor writeFunc
Browse files Browse the repository at this point in the history
  • Loading branch information
ehsannm committed May 9, 2024
1 parent a18b030 commit 874a8ed
Show file tree
Hide file tree
Showing 14 changed files with 155 additions and 115 deletions.
5 changes: 2 additions & 3 deletions kit/bridge_north.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ type Gateway interface {
type GatewayDelegate interface {
ConnDelegate
// OnMessage must be called whenever a new message arrives.
OnMessage(c Conn, wf WriteFunc, msg []byte)
OnMessage(c Conn, msg []byte)
}

type ConnDelegate interface {
Expand Down Expand Up @@ -84,10 +84,9 @@ func (n *northBridge) OnClose(connID uint64) {
n.cd.OnClose(connID)
}

func (n *northBridge) OnMessage(conn Conn, wf WriteFunc, msg []byte) {
func (n *northBridge) OnMessage(conn Conn, msg []byte) {
n.wg.Add(1)
ctx := n.acquireCtx(conn)
ctx.wf = wf
ctx.sb = n.sb
ctx.rawData = msg

Expand Down
14 changes: 7 additions & 7 deletions kit/bridge_south.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ func (sb *southBridge) OnMessage(data []byte) error {
sessionID: carrier.SessionID,
serverID: sb.id,
kv: map[string]string{},
wf: sb.writeFunc,
}
ctx := sb.acquireCtx(conn)
ctx.wf = sb.writeFunc
ctx.sb = sb

switch carrier.Kind {
Expand Down Expand Up @@ -262,12 +262,7 @@ func (sb *southBridge) genForwarderHandler(sel EdgeSelectorFunc) HandlerFunc {
}
}

func (sb *southBridge) writeFunc(conn Conn, e *Envelope) error {
c, ok := conn.(*clusterConn)
if !ok {
return ErrWritingToClusterConnection
}

func (sb *southBridge) writeFunc(c *clusterConn, e *Envelope) error {
ec := newEnvelopeCarrier(
outgoingCarrier,
c.sessionID,
Expand Down Expand Up @@ -295,6 +290,7 @@ type clusterConn struct {

kvMtx sync.Mutex
kv map[string]string
wf func(c *clusterConn, e *Envelope) error
}

var _ Conn = (*clusterConn)(nil)
Expand All @@ -311,6 +307,10 @@ func (c *clusterConn) Write(_ []byte) (int, error) {
return 0, ErrWritingToClusterConnection
}

func (c *clusterConn) WriteEnvelope(e *Envelope) error {
return c.wf(c, e)
}

func (c *clusterConn) Stream() bool {
return c.stream
}
Expand Down
1 change: 1 addition & 0 deletions kit/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ type Conn interface {
ConnID() uint64
ClientIP() string
Write(data []byte) (int, error)
WriteEnvelope(e *Envelope) error
Stream() bool
Walk(fn func(key string, val string) bool)
Get(key string) string
Expand Down
1 change: 0 additions & 1 deletion kit/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ type Context struct {
hdr map[string]string
conn Conn
in *Envelope
wf WriteFunc
modifiers []ModifierFunc
err error
statusCode int
Expand Down
27 changes: 9 additions & 18 deletions kit/ctx_testkit.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,15 +60,6 @@ func (testCtx *TestContext) Run(stream bool) error {
ctx.in.
SetMsg(testCtx.inMsg).
SetHdrMap(testCtx.inHdr)
ctx.wf = func(conn Conn, e *Envelope) error {
e.dontReuse()
tc := conn.(*testConn) //nolint:forcetypeassert
tc.Lock()
tc.out = append(tc.out, e)
tc.Unlock()

return nil
}
ctx.handlers = append(ctx.handlers, testCtx.handlers...)
ctx.Next()

Expand All @@ -85,15 +76,6 @@ func (testCtx *TestContext) RunREST() error {
ctx.in.
SetMsg(testCtx.inMsg).
SetHdrMap(testCtx.inHdr)
ctx.wf = func(conn Conn, e *Envelope) error {
e.dontReuse()
tc := conn.(*testRESTConn) //nolint:forcetypeassert
tc.Lock()
tc.out = append(tc.out, e)
tc.Unlock()

return nil
}
ctx.handlers = append(ctx.handlers, testCtx.handlers...)
ctx.Next()

Expand Down Expand Up @@ -130,6 +112,15 @@ func (t *testConn) Write(_ []byte) (int, error) {
return 0, nil
}

func (t *testConn) WriteEnvelope(e *Envelope) error {
e.dontReuse()
t.Lock()
t.out = append(t.out, e)
t.Unlock()

return nil
}

func (t *testConn) Stream() bool {
return t.stream
}
Expand Down
25 changes: 11 additions & 14 deletions kit/edge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,7 @@ type testGateway struct {
var _ kit.Gateway = (*testGateway)(nil)

func (t *testGateway) Send(c *testConn, msg []byte) {
t.d.OnMessage(
c,
func(conn kit.Conn, e *kit.Envelope) error {
b, err := kit.MarshalMessage(e.GetMsg())
if err != nil {
return err
}

_, err = conn.Write(b)

return err
},
msg,
)
t.d.OnMessage(c, msg)
}

func (t *testGateway) Start(_ context.Context, _ kit.GatewayStartConfig) error {
Expand Down Expand Up @@ -241,6 +228,16 @@ func (t testConn) Write(data []byte) (int, error) {
return t.buf.Write(data)
}

func (t *testConn) WriteEnvelope(e *kit.Envelope) error {
b, err := kit.MarshalMessage(e.GetMsg())
if err != nil {
return err
}
_, err = t.buf.Write(b)

return err
}

func (t testConn) Read() ([]byte, error) {
return io.ReadAll(t.buf)
}
Expand Down
2 changes: 1 addition & 1 deletion kit/envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (e *Envelope) Send() {
}

// Use WriteFunc to write the Envelope into the connection
e.ctx.Error(e.ctx.wf(e.conn, e))
e.ctx.Error(e.conn.WriteEnvelope(e))

// Release the envelope
e.release()
Expand Down
34 changes: 7 additions & 27 deletions std/gateways/fasthttp/bundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func (b *bundle) genHTTPHandler(rd routeData) fasthttp.RequestHandler {
c.ctx = ctx
c.rd = &rd
b.d.OnOpen(c)
b.d.OnMessage(c, b.writeFunc, ctx.PostBody())
b.d.OnMessage(c, ctx.PostBody())
b.d.OnClose(c.ConnID())

b.connPool.Put(c)
Expand All @@ -227,10 +227,11 @@ func (b *bundle) wsHandler(ctx *fasthttp.RequestCtx) {
ctx,
func(conn *websocket.Conn) {
wsc := &wsConn{
kv: map[string]string{},
id: atomic.AddUint64(&b.wsNextID, 1),
clientIP: realip.FromRequest(ctx),
c: conn,
kv: map[string]string{},
id: atomic.AddUint64(&b.wsNextID, 1),
clientIP: realip.FromRequest(ctx),
c: conn,
rpcOutFactory: b.rpcOutFactory,
}
b.d.OnOpen(wsc)
for {
Expand All @@ -249,7 +250,7 @@ func (b *bundle) wsHandler(ctx *fasthttp.RequestCtx) {
}

func (b *bundle) wsHandlerExec(buf *buf.Bytes, wsc *wsConn) {
b.d.OnMessage(wsc, b.writeFunc, *buf.Bytes())
b.d.OnMessage(wsc, *buf.Bytes())
buf.Release()
}

Expand Down Expand Up @@ -335,28 +336,7 @@ func (b *bundle) writeFunc(conn kit.Conn, e *kit.Envelope) error {
outC.Release()

return err
case *httpConn:
var (
data []byte
err error
)

data, err = kit.MarshalMessage(e.GetMsg())
if err != nil {
return err
}

e.WalkHdr(
func(key string, val string) bool {
c.ctx.Response.Header.Set(key, val)

return true
},
)

c.ctx.SetBody(data)

return nil
default:
panic("BUG!! incorrect connection")
}
Expand Down
24 changes: 24 additions & 0 deletions std/gateways/fasthttp/conn_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,30 @@ func (c *httpConn) Write(data []byte) (int, error) {
return len(data), nil
}

func (c *httpConn) WriteEnvelope(e *kit.Envelope) error {
var (
data []byte
err error
)

data, err = kit.MarshalMessage(e.GetMsg())
if err != nil {
return err
}

e.WalkHdr(
func(key string, val string) bool {
c.ctx.Response.Header.Set(key, val)

return true
},
)

c.ctx.SetBody(data)

return nil
}

func (c *httpConn) Stream() bool {
return false
}
Expand Down
32 changes: 28 additions & 4 deletions std/gateways/fasthttp/conn_ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ import (

type wsConn struct {
utils.SpinLock
kv map[string]string
id uint64
clientIP string
c *websocket.Conn
kv map[string]string
id uint64
clientIP string
c *websocket.Conn
rpcOutFactory kit.OutgoingRPCFactory
}

var _ kit.Conn = (*wsConn)(nil)
Expand Down Expand Up @@ -47,6 +48,29 @@ func (w *wsConn) Write(data []byte) (int, error) {
return len(data), nil
}

func (w *wsConn) WriteEnvelope(e *kit.Envelope) error {
outC := w.rpcOutFactory()
outC.InjectMessage(e.GetMsg())
outC.SetID(e.GetID())
e.WalkHdr(
func(key string, val string) bool {
outC.SetHdr(key, val)

return true
},
)

data, err := outC.Marshal()
if err != nil {
return err
}

_, err = w.Write(data)
outC.Release()

return err
}

func (w *wsConn) Stream() bool {
return true
}
Expand Down
42 changes: 36 additions & 6 deletions std/gateways/fastws/conn.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package fastws

import (
"github.com/clubpay/ronykit/kit"
"github.com/clubpay/ronykit/kit/errors"
"github.com/clubpay/ronykit/kit/utils"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
Expand All @@ -16,15 +18,22 @@ type wsConn struct {
r *wsutil.Reader
w *wsutil.Writer
handshakeDone bool
rpcOutFactory kit.OutgoingRPCFactory
}

func newWebsocketConn(id uint64, c gnet.Conn) *wsConn {
var _ kit.Conn = (*wsConn)(nil)

func newWebsocketConn(
id uint64, c gnet.Conn,
rpcOutFactory kit.OutgoingRPCFactory,
) *wsConn {
wsc := &wsConn{
w: wsutil.NewWriter(c, ws.StateServerSide, ws.OpText),
r: wsutil.NewReader(c, ws.StateServerSide),
id: id,
kv: map[string]string{},
c: c,
w: wsutil.NewWriter(c, ws.StateServerSide, ws.OpText),
r: wsutil.NewReader(c, ws.StateServerSide),
id: id,
kv: map[string]string{},
c: c,
rpcOutFactory: rpcOutFactory,
}

return wsc
Expand Down Expand Up @@ -61,6 +70,27 @@ func (c *wsConn) Write(data []byte) (int, error) {
return n, err
}

func (c *wsConn) WriteEnvelope(e *kit.Envelope) error {
outC := c.rpcOutFactory()
outC.InjectMessage(e.GetMsg())
outC.SetID(e.GetID())
e.WalkHdr(func(key string, val string) bool {
outC.SetHdr(key, val)

return true
})

data, err := outC.Marshal()
if err != nil {
return errors.Wrap(kit.ErrEncodeOutgoingMessageFailed, err)
}

_, err = c.Write(data)
outC.Release()

return err
}

func (c *wsConn) Stream() bool {
return true
}
Expand Down
Loading

0 comments on commit 874a8ed

Please sign in to comment.