Skip to content

Commit

Permalink
feat: zipper support middleware (#660)
Browse files Browse the repository at this point in the history
# Description

Adding `WithConnMiddleware()` and `WithFrameMiddleware()` zipper option,
their type is below:

```go
type (
	// FrameHandler handles a frame.
	FrameHandler func(*Context)
	// FrameMiddleware is a middleware for frame handler.
	FrameMiddleware func(FrameHandler) FrameHandler
)

type (
	// ConnHandler handles a connection and route.
	ConnHandler func(*Connection, router.Route)
	// ConnMiddleware is a middleware for connection handler.
	ConnMiddleware func(ConnHandler) ConnHandler
)
```

---------

Co-authored-by: venjiang <venjiang@gmail.com>
Co-authored-by: C.C <fanweixiao@gmail.com>
  • Loading branch information
3 people authored Oct 30, 2023
1 parent 92b21e7 commit 1754f26
Show file tree
Hide file tree
Showing 10 changed files with 249 additions and 248 deletions.
56 changes: 31 additions & 25 deletions core/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,24 +45,23 @@ func TestFrameRoundTrip(t *testing.T) {
backflow = []byte("hello backflow frame")
)

// test server hooks
ht := &hookTester{t: t}

server := NewServer("zipper",
WithAuth("token", "auth-token"),
WithServerQuicConfig(DefalutQuicConfig),
WithServerTLSConfig(nil),
WithServerLogger(discardingLogger),
WithConnMiddleware(ht.connMiddleware),
WithFrameMiddleware(ht.frameMiddleware),
)
server.ConfigRouter(router.Default())

// test server hooks
ht := &hookTester{t}
server.SetStartHandlers(ht.startHandler)
server.SetBeforeHandlers(ht.beforeHandler)
server.SetAfterHandlers(ht.afterHandler)

recorder := newFrameWriterRecorder("mockID", "mockClientLocal", "mockClientRemote")
server.AddDownstreamServer(recorder)

assert.Equal(t, server.Downstreams()["mockID"], recorder.ID())
assert.Equal(t, server.Downstreams()["mockClientLocal"], recorder.ID())

go func() {
err := server.ListenAndServe(ctx, testaddr)
Expand Down Expand Up @@ -191,29 +190,36 @@ func checkClientExited(client *Client, tim time.Duration) bool {
}

type hookTester struct {
t *testing.T
}

func (a *hookTester) startHandler(ctx *Context) error {
ctx.Set("start", "yes")
return nil
mu sync.Mutex
connNames []string
t *testing.T
}

func (a *hookTester) beforeHandler(ctx *Context) error {
ctx.Set("before", "ok")
return nil
}
func (a *hookTester) connMiddleware(next ConnHandler) ConnHandler {
return func(c *Connection, r router.Route) {
a.mu.Lock()
if a.connNames == nil {
a.connNames = make([]string, 0)
}
a.connNames = append(a.connNames, c.Name())
a.mu.Unlock()

func (a *hookTester) afterHandler(ctx *Context) error {
v, ok := ctx.Get("start")
assert.True(a.t, ok)
assert.Equal(a.t, v, "yes")
next(c, r)

v = ctx.Value("before")
assert.True(a.t, ok)
assert.Equal(a.t, v, "ok")
a.mu.Lock()
assert.Contains(a.t, a.connNames, c.Name())
a.mu.Unlock()
}
}

return nil
func (a *hookTester) frameMiddleware(next FrameHandler) FrameHandler {
return func(c *Context) {
c.Set("a", "b")
next(c)
v, ok := c.Get("a")
assert.True(a.t, ok)
assert.Equal(a.t, "b", v)
}
}

func createTestStreamFunction(name string, zipperAddr string, observedTag frame.Tag) *Client {
Expand Down
57 changes: 34 additions & 23 deletions core/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/quic-go/quic-go"
"github.com/yomorun/yomo/core/frame"
"github.com/yomorun/yomo/core/metadata"
"golang.org/x/exp/slog"
)

// ConnectionInfo holds the information of connection.
Expand All @@ -22,77 +23,87 @@ type ConnectionInfo interface {
ObserveDataTags() []frame.Tag
}

// Connection wraps conneciton and stream to transfer frames.
// Connection be used to read and write frames, and be managed by Connector.
type Connection interface {
Context() context.Context
ConnectionInfo
frame.ReadWriteCloser
// CloseWithError closes the connection with an error string.
CloseWithError(string) error
}

type connection struct {
// Connection wraps connection and stream for transmitting frames, it can be
// used for reading and writing frames, and is managed by the Connector.
type Connection struct {
name string
id string
clientType ClientType
metadata metadata.M
observeDataTags []uint32
conn quic.Connection
fs *FrameStream
Logger *slog.Logger
}

func newConnection(
name string, id string, clientType ClientType, md metadata.M, tags []uint32,
conn quic.Connection, fs *FrameStream) *connection {
return &connection{
conn quic.Connection, fs *FrameStream, logger *slog.Logger) *Connection {

logger = logger.With("conn_id", id, "conn_name", name)
if conn != nil {
logger.Info("new client connected", "remote_addr", conn.RemoteAddr().String(), "client_type", clientType.String())
}

return &Connection{
name: name,
id: id,
clientType: clientType,
metadata: md,
observeDataTags: tags,
conn: conn,
fs: fs,
Logger: logger,
}
}

func (c *connection) Close() error {
// Close closes the connection.
func (c *Connection) Close() error {
return c.fs.Close()
}

func (c *connection) Context() context.Context {
// Context returns the context of the connection.
func (c *Connection) Context() context.Context {
return c.fs.Context()
}

func (c *connection) ID() string {
// ID returns the connection ID.
func (c *Connection) ID() string {
return c.id
}

func (c *connection) Metadata() metadata.M {
// Metadata returns the extra info of the application.
func (c *Connection) Metadata() metadata.M {
return c.metadata
}

func (c *connection) Name() string {
// Name returns the name of the connection
func (c *Connection) Name() string {
return c.name
}

func (c *connection) ObserveDataTags() []uint32 {
// ObserveDataTags returns the observed data tags.
func (c *Connection) ObserveDataTags() []uint32 {
return c.observeDataTags
}

func (c *connection) ReadFrame() (frame.Frame, error) {
// ReadFrame reads a frame from the connection.
func (c *Connection) ReadFrame() (frame.Frame, error) {
return c.fs.ReadFrame()
}

func (c *connection) ClientType() ClientType {
// ClientType returns the client type of the connection.
func (c *Connection) ClientType() ClientType {
return c.clientType
}

func (c *connection) WriteFrame(f frame.Frame) error {
// WriteFrame writes a frame to the connection.
func (c *Connection) WriteFrame(f frame.Frame) error {
return c.fs.WriteFrame(f)
}

func (c *connection) CloseWithError(errString string) error {
// CloseWithError closes the connection with error.
func (c *Connection) CloseWithError(errString string) error {
return c.conn.CloseWithError(YomoCloseErrorCode, errString)
}

Expand Down
3 changes: 2 additions & 1 deletion core/connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/yomorun/yomo/core/frame"
"github.com/yomorun/yomo/core/metadata"
"golang.org/x/exp/slog"
)

func TestConnection(t *testing.T) {
Expand All @@ -30,7 +31,7 @@ func TestConnection(t *testing.T) {
// create frame connection.
fs := NewFrameStream(mockStream, &byteCodec{}, &bytePacketReadWriter{})

connection := newConnection(name, id, styp, md, observed, nil, fs)
connection := newConnection(name, id, styp, md, observed, nil, fs, slog.Default())

t.Run("ConnectionInfo", func(t *testing.T) {
assert.Equal(t, id, connection.ID())
Expand Down
26 changes: 12 additions & 14 deletions core/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func NewConnector(ctx context.Context) *Connector {
// Store stores Connection to Connector,
// If the connID is the same twice, the new connection will replace the old connection.
// If Connector be closed, The function will return ErrConnectorClosed.
func (c *Connector) Store(connID string, conn Connection) error {
func (c *Connector) Store(connID string, conn *Connection) error {
select {
case <-c.ctx.Done():
return ErrConnectorClosed
Expand Down Expand Up @@ -62,7 +62,7 @@ func (c *Connector) Remove(connID string) error {
// Get retrieves the Connection with the specified connID.
// If the Connector does not have a connection with the given connID, return nil and false.
// If Connector be closed, The function will return ErrConnectorClosed.
func (c *Connector) Get(connID string) (Connection, bool, error) {
func (c *Connector) Get(connID string) (*Connection, bool, error) {
select {
case <-c.ctx.Done():
return nil, false, ErrConnectorClosed
Expand All @@ -74,29 +74,27 @@ func (c *Connector) Get(connID string) (Connection, bool, error) {
return nil, false, nil
}

connection := v.(Connection)

return connection, true, nil
return v.(*Connection), true, nil
}

// FindConnectionFunc is used to search for a specific connection within the Connector.
type FindConnectionFunc func(ConnectionInfo) bool

// Find searches a stream collection using the specified find function.
// If Connector be closed, The function will return ErrConnectorClosed.
func (c *Connector) Find(findFunc FindConnectionFunc) ([]Connection, error) {
func (c *Connector) Find(findFunc FindConnectionFunc) ([]*Connection, error) {
select {
case <-c.ctx.Done():
return []Connection{}, ErrConnectorClosed
return []*Connection{}, ErrConnectorClosed
default:
}

connections := make([]Connection, 0)
connections := make([]*Connection, 0)
c.connections.Range(func(key interface{}, val interface{}) bool {
stream := val.(Connection)
conn := val.(*Connection)

if findFunc(stream) {
connections = append(connections, stream)
if findFunc(conn) {
connections = append(connections, conn)
}
return true
})
Expand All @@ -112,10 +110,10 @@ func (c *Connector) Snapshot() map[string]string {

c.connections.Range(func(key interface{}, val interface{}) bool {
var (
streamID = key.(string)
stream = val.(Connection)
connID = key.(string)
conn = val.(*Connection)
)
result[streamID] = stream.Name()
result[connID] = conn.Name()
return true
})

Expand Down
11 changes: 5 additions & 6 deletions core/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

"github.com/stretchr/testify/assert"
"github.com/yomorun/yomo/core/frame"
"golang.org/x/exp/slog"
)

func TestConnector(t *testing.T) {
Expand Down Expand Up @@ -37,10 +38,9 @@ func TestConnector(t *testing.T) {
err = connector.Remove(connID)
assert.NoError(t, err)

gotStream, ok, err := connector.Get(connID)
_, ok, err := connector.Get(connID)
assert.NoError(t, err)
assert.False(t, ok)
assert.Equal(t, gotStream, nil)
})

t.Run("Find", func(t *testing.T) {
Expand Down Expand Up @@ -92,9 +92,8 @@ func TestConnector(t *testing.T) {

t.Run("Get", func(t *testing.T) {
conn1 := mockConn("id-1", "name-1")
gotStream, ok, err := connector.Get(conn1.ID())
_, ok, err := connector.Get(conn1.ID())
assert.False(t, ok)
assert.Equal(t, gotStream, nil)
assert.ErrorIs(t, err, ErrConnectorClosed)
})

Expand All @@ -117,6 +116,6 @@ func TestConnector(t *testing.T) {

// mockConn returns a connection that only includes an ID and a name.
// This function is used for unit testing purposes.
func mockConn(id, name string) Connection {
return newConnection(name, id, ClientType(0), nil, []frame.Tag{0}, nil, nil)
func mockConn(id, name string) *Connection {
return newConnection(name, id, ClientType(0), nil, []frame.Tag{0}, nil, nil, slog.Default())
}
Loading

1 comment on commit 1754f26

@vercel
Copy link

@vercel vercel bot commented on 1754f26 Oct 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Successfully deployed to the following URLs:

yomo – ./

yomo-yomorun.vercel.app
yomo.vercel.app
www.yomo.run
yomo.run
yomo-git-master-yomorun.vercel.app

Please sign in to comment.