Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix staticcheck warning SA1029: inappropriate key in call to context.WithValue #693

Merged
merged 2 commits into from
Oct 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions p2p/simulations/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ func (s *Server) wrapHandler(handler http.HandlerFunc) httprouter.Handle {
http.NotFound(w, req)
return
}
//lint:ignore SA1029 This file will be removed later, reference: #30250
ctx = context.WithValue(ctx, "node", node)
}

Expand All @@ -735,6 +736,7 @@ func (s *Server) wrapHandler(handler http.HandlerFunc) httprouter.Handle {
http.NotFound(w, req)
return
}
//lint:ignore SA1029 This file will be removed later, reference: #30250
ctx = context.WithValue(ctx, "peer", peer)
}

Expand Down
5 changes: 3 additions & 2 deletions rpc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ type clientConn struct {

func (c *Client) newClientConn(conn ServerCodec) *clientConn {
ctx := context.WithValue(context.Background(), clientContextKey{}, c)
ctx = context.WithValue(ctx, peerInfoContextKey{}, conn.peerInfo())
gzliudan marked this conversation as resolved.
Show resolved Hide resolved
handler := newHandler(ctx, conn, c.idgen, c.services)
return &clientConn{conn, handler}
}
Expand Down Expand Up @@ -473,7 +474,7 @@ func (c *Client) Subscribe(ctx context.Context, namespace string, channel interf
// Check type of channel first.
chanVal := reflect.ValueOf(channel)
if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 {
panic("first argument to Subscribe must be a writable channel")
panic(fmt.Sprintf("channel argument of Subscribe has type %T, need writable channel", channel))
}
if chanVal.IsNil() {
panic("channel given to Subscribe must not be nil")
Expand Down Expand Up @@ -532,8 +533,8 @@ func (c *Client) send(ctx context.Context, op *requestOp, msg interface{}) error
}

func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error {
// The previous write failed. Try to establish a new connection.
if c.writeConn == nil {
// The previous write failed. Try to establish a new connection.
if err := c.reconnect(ctx); err != nil {
return err
}
Expand Down
30 changes: 18 additions & 12 deletions rpc/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,18 @@ type httpConn struct {
headers http.Header
}

// httpConn is treated specially by Client.
// httpConn implements ServerCodec, but it is treated specially by Client
// and some methods don't work. The panic() stubs here exist to ensure
// this special treatment is correct.

func (hc *httpConn) writeJSON(context.Context, interface{}) error {
panic("writeJSON called on httpConn")
}

func (hc *httpConn) peerInfo() PeerInfo {
panic("peerInfo called on httpConn")
}

func (hc *httpConn) remoteAddr() string {
return hc.url
}
Expand Down Expand Up @@ -258,20 +265,19 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), code)
return
}

// Create request-scoped context.
connInfo := PeerInfo{Transport: "http", RemoteAddr: r.RemoteAddr}
connInfo.HTTP.Version = r.Proto
connInfo.HTTP.Host = r.Host
connInfo.HTTP.Origin = r.Header.Get("Origin")
connInfo.HTTP.UserAgent = r.Header.Get("User-Agent")
ctx := r.Context()
ctx = context.WithValue(ctx, peerInfoContextKey{}, connInfo)

// All checks passed, create a codec that reads directly from the request body
// until EOF, writes the response to w, and orders the server to process a
// single request.
ctx := r.Context()
ctx = context.WithValue(ctx, "remote", r.RemoteAddr)
ctx = context.WithValue(ctx, "scheme", r.Proto)
ctx = context.WithValue(ctx, "local", r.Host)
if ua := r.Header.Get("User-Agent"); ua != "" {
ctx = context.WithValue(ctx, "User-Agent", ua)
}
if origin := r.Header.Get("Origin"); origin != "" {
ctx = context.WithValue(ctx, "Origin", origin)
}

w.Header().Set("content-type", contentType)
codec := newHTTPServerConn(r, w)
defer codec.close()
Expand Down
36 changes: 36 additions & 0 deletions rpc/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,3 +123,39 @@ func TestHTTPRespBodyUnlimited(t *testing.T) {
t.Fatalf("response has wrong length %d, want %d", len(r), respLength)
}
}

func TestHTTPPeerInfo(t *testing.T) {
s := newTestServer()
defer s.Stop()
ts := httptest.NewServer(s)
defer ts.Close()

c, err := Dial(ts.URL)
if err != nil {
t.Fatal(err)
}
c.SetHeader("user-agent", "ua-testing")
c.SetHeader("origin", "origin.example.com")

// Request peer information.
var info PeerInfo
if err := c.Call(&info, "test_peerInfo"); err != nil {
t.Fatal(err)
}

if info.RemoteAddr == "" {
t.Error("RemoteAddr not set")
}
if info.Transport != "http" {
t.Errorf("wrong Transport %q", info.Transport)
}
if info.HTTP.Version != "HTTP/1.1" {
t.Errorf("wrong HTTP.Version %q", info.HTTP.Version)
}
if info.HTTP.UserAgent != "ua-testing" {
t.Errorf("wrong HTTP.UserAgent %q", info.HTTP.UserAgent)
}
if info.HTTP.Origin != "origin.example.com" {
t.Errorf("wrong HTTP.Origin %q", info.HTTP.UserAgent)
}
}
5 changes: 5 additions & 0 deletions rpc/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,11 @@ func NewCodec(conn Conn) ServerCodec {
return NewFuncCodec(conn, enc.Encode, dec.Decode)
}

func (c *jsonCodec) peerInfo() PeerInfo {
// This returns "ipc" because all other built-in transports have a separate codec type.
return PeerInfo{Transport: "ipc", RemoteAddr: c.remote}
}

func (c *jsonCodec) remoteAddr() string {
return c.remote
}
Expand Down
35 changes: 35 additions & 0 deletions rpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,38 @@ func (s *RPCService) Modules() map[string]string {
}
return modules
}

// PeerInfo contains information about the remote end of the network connection.
//
// This is available within RPC method handlers through the context. Call
// PeerInfoFromContext to get information about the client connection related to
// the current method call.
type PeerInfo struct {
// Transport is name of the protocol used by the client.
// This can be "http", "ws" or "ipc".
Transport string

// Address of client. This will usually contain the IP address and port.
RemoteAddr string

// Addditional information for HTTP and WebSocket connections.
HTTP struct {
// Protocol version, i.e. "HTTP/1.1". This is not set for WebSocket.
Version string
// Header values sent by the client.
UserAgent string
Origin string
Host string
}
}

type peerInfoContextKey struct{}

// PeerInfoFromContext returns information about the client's network connection.
// Use this with the context passed to RPC method handler functions.
//
// The zero value is returned if no connection info is present in ctx.
func PeerInfoFromContext(ctx context.Context) PeerInfo {
info, _ := ctx.Value(peerInfoContextKey{}).(PeerInfo)
return info
}
2 changes: 1 addition & 1 deletion rpc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestServerRegisterName(t *testing.T) {
t.Fatalf("Expected service calc to be registered")
}

wantCallbacks := 10
wantCallbacks := 11
if len(svc.callbacks) != wantCallbacks {
t.Errorf("Expected %d callbacks for service 'service', got %d", wantCallbacks, len(svc.callbacks))
}
Expand Down
4 changes: 4 additions & 0 deletions rpc/testservice_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ func (s *testService) EchoWithCtx(ctx context.Context, str string, i int, args *
return echoResult{str, i, args}
}

func (s *testService) PeerInfo(ctx context.Context) PeerInfo {
return PeerInfoFromContext(ctx)
}

func (s *testService) Sleep(ctx context.Context, duration time.Duration) {
time.Sleep(duration)
}
Expand Down
2 changes: 2 additions & 0 deletions rpc/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ type DataError interface {
// a RPC session. Implementations must be go-routine safe since the codec can be called in
// multiple go-routines concurrently.
type ServerCodec interface {
peerInfo() PeerInfo
readBatch() (msgs []*jsonrpcMessage, isBatch bool, err error)
close()

jsonWriter
}

Expand Down
20 changes: 17 additions & 3 deletions rpc/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler {
log.Debug("WebSocket upgrade failed", "err", err)
return
}
codec := newWebsocketCodec(conn)
codec := newWebsocketCodec(conn, r.Host, r.Header)
s.ServeCodec(codec, 0)
})
}
Expand Down Expand Up @@ -203,7 +203,7 @@ func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, diale
}
return nil, hErr
}
return newWebsocketCodec(conn), nil
return newWebsocketCodec(conn, endpoint, header), nil
})
}

Expand Down Expand Up @@ -241,18 +241,28 @@ func wsClientHeaders(endpoint, origin string) (string, http.Header, error) {
type websocketCodec struct {
*jsonCodec
conn *websocket.Conn
info PeerInfo

wg sync.WaitGroup
pingReset chan struct{}
}

func newWebsocketCodec(conn *websocket.Conn) ServerCodec {
func newWebsocketCodec(conn *websocket.Conn, host string, req http.Header) ServerCodec {
conn.SetReadLimit(wsMessageSizeLimit)
wc := &websocketCodec{
jsonCodec: NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON).(*jsonCodec),
conn: conn,
pingReset: make(chan struct{}, 1),
info: PeerInfo{
Transport: "ws",
RemoteAddr: conn.RemoteAddr().String(),
},
}
// Fill in connection details.
wc.info.HTTP.Host = host
wc.info.HTTP.Origin = req.Get("Origin")
wc.info.HTTP.UserAgent = req.Get("User-Agent")
// Start pinger.
wc.wg.Add(1)
go wc.pingLoop()
return wc
Expand All @@ -263,6 +273,10 @@ func (wc *websocketCodec) close() {
wc.wg.Wait()
}

func (wc *websocketCodec) peerInfo() PeerInfo {
return wc.info
}

func (wc *websocketCodec) writeJSON(ctx context.Context, v interface{}) error {
err := wc.jsonCodec.writeJSON(ctx, v)
if err == nil {
Expand Down
35 changes: 35 additions & 0 deletions rpc/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,41 @@ func TestWebsocketLargeCall(t *testing.T) {
}
}

func TestWebsocketPeerInfo(t *testing.T) {
var (
s = newTestServer()
ts = httptest.NewServer(s.WebsocketHandler([]string{"origin.example.com"}))
tsurl = "ws:" + strings.TrimPrefix(ts.URL, "http:")
)
defer s.Stop()
defer ts.Close()

ctx := context.Background()
c, err := DialWebsocket(ctx, tsurl, "origin.example.com")
if err != nil {
t.Fatal(err)
}

// Request peer information.
var connInfo PeerInfo
if err := c.Call(&connInfo, "test_peerInfo"); err != nil {
t.Fatal(err)
}

if connInfo.RemoteAddr == "" {
t.Error("RemoteAddr not set")
}
if connInfo.Transport != "ws" {
t.Errorf("wrong Transport %q", connInfo.Transport)
}
if connInfo.HTTP.UserAgent != "Go-http-client/1.1" {
t.Errorf("wrong HTTP.UserAgent %q", connInfo.HTTP.UserAgent)
}
if connInfo.HTTP.Origin != "origin.example.com" {
t.Errorf("wrong HTTP.Origin %q", connInfo.HTTP.UserAgent)
}
}

// This test checks that client handles WebSocket ping frames correctly.
func TestClientWebsocketPing(t *testing.T) {
t.Parallel()
Expand Down