Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rpc: support inject http headers using outgoing context #26023

Merged
merged 14 commits into from
Nov 16, 2022
56 changes: 56 additions & 0 deletions rpc/context_headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright 2022 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.

package rpc

import (
"context"
"net/http"
)

type mdHeaderKey struct{}

// NewContextWithHeaders wraps the given context, adding HTTP headers. These headers will
// be applied by Client when making a request using the returned context.
func NewContextWithHeaders(ctx context.Context, h http.Header) context.Context {
if len(h) == 0 {
// This check ensures the header map set in context will never be nil.
return ctx
}

var ctxh http.Header
prev, ok := ctx.Value(mdHeaderKey{}).(http.Header)
if ok {
ctxh = setHeaders(prev.Clone(), h)
} else {
ctxh = h.Clone()
}
return context.WithValue(ctx, mdHeaderKey{}, ctxh)
}

// headersFromContext is used to extract http.Header from context.
func headersFromContext(ctx context.Context) http.Header {
source, _ := ctx.Value(mdHeaderKey{}).(http.Header)
return source
}

// setHeaders sets all headers from src in dst.
func setHeaders(dst http.Header, src http.Header) http.Header {
for key, values := range src {
dst[http.CanonicalHeaderKey(key)] = values
}
return dst
}
2 changes: 2 additions & 0 deletions rpc/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
hc.mu.Lock()
req.Header = hc.headers.Clone()
hc.mu.Unlock()
setHeaders(req.Header, headersFromContext(ctx))

if hc.auth != nil {
if err := hc.auth(req.Header); err != nil {
return nil, err
Expand Down
42 changes: 42 additions & 0 deletions rpc/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
package rpc

import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"strings"
Expand Down Expand Up @@ -198,3 +200,43 @@ func TestHTTPPeerInfo(t *testing.T) {
t.Errorf("wrong HTTP.Origin %q", info.HTTP.UserAgent)
}
}

func TestNewContextWithHeaders(t *testing.T) {
expectedHeaders := 0
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
for i := 0; i < expectedHeaders; i++ {
key, want := fmt.Sprintf("key-%d", i), fmt.Sprintf("val-%d", i)
if have := request.Header.Get(key); have != want {
t.Errorf("wrong request headers for %s, want: %s, have: %s", key, want, have)
}
}
writer.WriteHeader(http.StatusOK)
_, _ = writer.Write([]byte(`{}`))
}))
defer server.Close()

client, err := Dial(server.URL)
if err != nil {
t.Fatalf("failed to dial: %s", err)
}
defer client.Close()

newHdr := func(k, v string) http.Header {
header := http.Header{}
header.Set(k, v)
return header
}
ctx1 := NewContextWithHeaders(context.Background(), newHdr("key-0", "val-0"))
ctx2 := NewContextWithHeaders(ctx1, newHdr("key-1", "val-1"))
ctx3 := NewContextWithHeaders(ctx2, newHdr("key-2", "val-2"))

expectedHeaders = 3
if err := client.CallContext(ctx3, nil, "test"); err != ErrNoResult {
t.Error("call failed", err)
}

expectedHeaders = 2
if err := client.CallContext(ctx2, nil, "test"); err != ErrNoResult {
t.Error("call failed:", err)
}
storyicon marked this conversation as resolved.
Show resolved Hide resolved
}