Skip to content

Commit

Permalink
feat: add ConnectionType(ctx) for called methods to use (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
rvagg authored Aug 5, 2024
1 parent 61205be commit cd4bed8
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 8 deletions.
35 changes: 27 additions & 8 deletions rpc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ import (
)

func init() {
if err := logging.SetLogLevel("rpc", "DEBUG"); err != nil {
panic(err)
if _, exists := os.LookupEnv("GOLOG_LOG_LEVEL"); !exists {
if err := logging.SetLogLevel("rpc", "DEBUG"); err != nil {
panic(err)
}
}

debugTrace = true
Expand Down Expand Up @@ -497,15 +499,17 @@ func TestParallelRPC(t *testing.T) {
type CtxHandler struct {
lk sync.Mutex

cancelled bool
i int
cancelled bool
i int
connectionType ConnectionType
}

func (h *CtxHandler) Test(ctx context.Context) {
h.lk.Lock()
defer h.lk.Unlock()
timeout := time.After(300 * time.Millisecond)
h.i++
h.connectionType = GetConnectionType(ctx)

select {
case <-timeout:
Expand Down Expand Up @@ -543,6 +547,9 @@ func TestCtx(t *testing.T) {
if !serverHandler.cancelled {
t.Error("expected cancellation on the server side")
}
if serverHandler.connectionType != ConnectionTypeWS {
t.Error("wrong connection type")
}

serverHandler.cancelled = false

Expand All @@ -564,6 +571,9 @@ func TestCtx(t *testing.T) {
if serverHandler.cancelled || serverHandler.i != 2 {
t.Error("wrong serverHandler state")
}
if serverHandler.connectionType != ConnectionTypeWS {
t.Error("wrong connection type")
}

serverHandler.lk.Unlock()
closer()
Expand Down Expand Up @@ -598,6 +608,9 @@ func TestCtxHttp(t *testing.T) {
if !serverHandler.cancelled {
t.Error("expected cancellation on the server side")
}
if serverHandler.connectionType != ConnectionTypeHTTP {
t.Error("wrong connection type")
}

serverHandler.cancelled = false

Expand All @@ -619,6 +632,10 @@ func TestCtxHttp(t *testing.T) {
if serverHandler.cancelled || serverHandler.i != 2 {
t.Error("wrong serverHandler state")
}
// connection type should have switched to WS
if serverHandler.connectionType != ConnectionTypeWS {
t.Error("wrong connection type")
}

serverHandler.lk.Unlock()
closer()
Expand Down Expand Up @@ -1007,10 +1024,12 @@ func TestChanClientReceiveAll(t *testing.T) {
}

func TestControlChanDeadlock(t *testing.T) {
_ = logging.SetLogLevel("rpc", "error")
defer func() {
_ = logging.SetLogLevel("rpc", "debug")
}()
if _, exists := os.LookupEnv("GOLOG_LOG_LEVEL"); !exists {
_ = logging.SetLogLevel("rpc", "error")
defer func() {
_ = logging.SetLogLevel("rpc", "DEBUG")
}()
}

for r := 0; r < 20; r++ {
testControlChanDeadlock(t)
Expand Down
27 changes: 27 additions & 0 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,31 @@ const (
rpcInvalidParams = -32602
)

// ConnectionType indicates the type of connection, this is set in the context and can be retrieved
// with GetConnectionType.
type ConnectionType string

const (
// ConnectionTypeUnknown indicates that the connection type cannot be determined, likely because
// it hasn't passed through an RPCServer.
ConnectionTypeUnknown ConnectionType = "unknown"
// ConnectionTypeHTTP indicates that the connection is an HTTP connection.
ConnectionTypeHTTP ConnectionType = "http"
// ConnectionTypeWS indicates that the connection is a WebSockets connection.
ConnectionTypeWS ConnectionType = "websockets"
)

var connectionTypeCtxKey = &struct{ name string }{"jsonrpc-connection-type"}

// GetConnectionType returns the connection type of the request if it was set by an RPCServer.
// A connection type of ConnectionTypeUnknown means that the connection type was not set.
func GetConnectionType(ctx context.Context) ConnectionType {
if v := ctx.Value(connectionTypeCtxKey); v != nil {
return v.(ConnectionType)
}
return ConnectionTypeUnknown
}

// RPCServer provides a jsonrpc 2.0 http server handler
type RPCServer struct {
*handler
Expand Down Expand Up @@ -97,10 +122,12 @@ func (s *RPCServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {

h := strings.ToLower(r.Header.Get("Connection"))
if strings.Contains(h, "upgrade") {
ctx = context.WithValue(ctx, connectionTypeCtxKey, ConnectionTypeWS)
s.handleWS(ctx, w, r)
return
}

ctx = context.WithValue(ctx, connectionTypeCtxKey, ConnectionTypeHTTP)
s.handleReader(ctx, r.Body, w, rpcError)
}

Expand Down

0 comments on commit cd4bed8

Please sign in to comment.