Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add middleware level for incoming messages and add context of request to session #147

Open
wants to merge 4 commits into
base: v2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions peer.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package turnpike

import (
"context"
"fmt"
"time"
)
Expand All @@ -24,6 +25,12 @@ type Peer interface {

// Receive returns a channel of messages coming from the peer.
Receive() <-chan Message

//AddIncomeMiddleware implements preprocess income messages
AddIncomeMiddleware(f func(Message) (Message, error))

//GetContext returns context
GetContext() context.Context
}

// GetMessageTimeout is a convenience function to get a single message from a
Expand Down
13 changes: 13 additions & 0 deletions session.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package turnpike

import (
"context"
"fmt"
)

Expand All @@ -27,10 +28,12 @@ func localPipe() (*localPeer, *localPeer) {
a := &localPeer{
incoming: bToA,
outgoing: aToB,
ctx: context.TODO(),
}
b := &localPeer{
incoming: aToB,
outgoing: bToA,
ctx: context.TODO(),
}

return a, b
Expand All @@ -39,6 +42,7 @@ func localPipe() (*localPeer, *localPeer) {
type localPeer struct {
outgoing chan<- Message
incoming <-chan Message
ctx context.Context
}

func (s *localPeer) Receive() <-chan Message {
Expand All @@ -54,3 +58,12 @@ func (s *localPeer) Close() error {
close(s.outgoing)
return nil
}

//todo
func (s *localPeer) AddIncomeMiddleware(f func(Message) (Message, error)) {

}

func (s *localPeer) GetContext() context.Context {
return s.ctx
}
49 changes: 38 additions & 11 deletions websocket.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package turnpike

import (
"context"
"crypto/tls"
"fmt"
"net/http"
Expand All @@ -11,12 +12,14 @@ import (
)

type websocketPeer struct {
conn *websocket.Conn
serializer Serializer
messages chan Message
payloadType int
closed bool
sendMutex sync.Mutex
incomeMiddleware func(Message) (Message, error)
conn *websocket.Conn
serializer Serializer
messages chan Message
payloadType int
closed bool
sendMutex sync.Mutex
ctx context.Context
}

func NewWebsocketPeer(serialization Serialization, url string, requestHeader http.Header, tlscfg *tls.Config, dial DialFunc) (Peer, error) {
Expand Down Expand Up @@ -46,10 +49,12 @@ func newWebsocketPeer(url string, reqHeader http.Header, protocol string, serial
return nil, err
}
ep := &websocketPeer{
conn: conn,
messages: make(chan Message, 10),
serializer: serializer,
payloadType: payloadType,
conn: conn,
messages: make(chan Message, 10),
serializer: serializer,
payloadType: payloadType,
incomeMiddleware: nil,
ctx: context.TODO(),
}
go ep.run()

Expand Down Expand Up @@ -79,6 +84,18 @@ func (ep *websocketPeer) Close() error {
return ep.conn.Close()
}

//AddIncomeMiddleware implements preprocess income messages
func (ep *websocketPeer) AddIncomeMiddleware(f func(Message) (Message, error)) {
ep.sendMutex.Lock()
ep.incomeMiddleware = f
ep.sendMutex.Unlock()
}

//GetContext returns context
func (ep *websocketPeer) GetContext() context.Context {
return ep.ctx
}

func (ep *websocketPeer) run() {
for {
// TODO: use conn.NextMessage() and stream
Expand All @@ -102,7 +119,17 @@ func (ep *websocketPeer) run() {
log.Println("error deserializing peer message:", err)
// TODO: handle error
} else {
ep.messages <- msg
if ep.incomeMiddleware == nil {
ep.messages <- msg
} else {
m, err := ep.incomeMiddleware(msg)
if err != nil {
log.Println(err)
} else {
ep.messages <- m
}
}

}
}
}
Expand Down
6 changes: 4 additions & 2 deletions websocket_server.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package turnpike

import (
"context"
"fmt"
"net/http"

Expand Down Expand Up @@ -108,10 +109,10 @@ func (s *WebsocketServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
s.handleWebsocket(conn)
s.handleWebsocket(conn, r.Context())
}

func (s *WebsocketServer) handleWebsocket(conn *websocket.Conn) {
func (s *WebsocketServer) handleWebsocket(conn *websocket.Conn, ctx context.Context) {
var serializer Serializer
var payloadType int
if proto, ok := s.protocols[conn.Subprotocol()]; ok {
Expand Down Expand Up @@ -139,6 +140,7 @@ func (s *WebsocketServer) handleWebsocket(conn *websocket.Conn) {
serializer: serializer,
messages: make(chan Message, 10),
payloadType: payloadType,
ctx: ctx,
}
go peer.run()

Expand Down