Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rpc: support inject http headers using outgoing context #26023

Merged
merged 14 commits into from
Nov 16, 2022
2 changes: 2 additions & 0 deletions rpc/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ func (hc *httpConn) doRequest(ctx context.Context, msg interface{}) (io.ReadClos
hc.mu.Lock()
req.Header = hc.headers.Clone()
hc.mu.Unlock()
mergeHeadersFromOutgoingContext(ctx, req.Header)

if hc.auth != nil {
if err := hc.auth(req.Header); err != nil {
return nil, err
Expand Down
27 changes: 27 additions & 0 deletions rpc/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package rpc

import (
"context"
"net/http"
"net/http/httptest"
"strings"
Expand Down Expand Up @@ -198,3 +199,29 @@ func TestHTTPPeerInfo(t *testing.T) {
t.Errorf("wrong HTTP.Origin %q", info.HTTP.UserAgent)
}
}

func TestNewOutgoingContext(t *testing.T) {
const (
testHeaderKey = "test-header-key"
testHeaderValue = "test-header-value"
)
server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
if got := request.Header.Get(testHeaderKey); got != testHeaderValue {
t.Errorf("wrong request headers for %s, expected: %s, actual: %s", testHeaderKey, testHeaderValue, got)
}
writer.WriteHeader(http.StatusOK)
_, _ = writer.Write([]byte(`{}`))
}))
defer server.Close()
client, err := Dial(server.URL)
if err != nil {
panic(err)
storyicon marked this conversation as resolved.
Show resolved Hide resolved
}
defer client.Close()
header := http.Header{}
header.Set(testHeaderKey, testHeaderValue)
ctx := NewOutgoingContext(context.TODO(), header)
if err := client.CallContext(ctx, &struct{}{}, "test"); err != ErrNoResult {
t.Errorf("failed to call context: %s", err)
}
storyicon marked this conversation as resolved.
Show resolved Hide resolved
}
42 changes: 42 additions & 0 deletions rpc/metadata.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package rpc

import (
"context"
"net/http"
)

type rawMd struct {
headers http.Header
}
storyicon marked this conversation as resolved.
Show resolved Hide resolved

type mdOutgoingKey struct{}

// NewOutgoingContext is used to attach http headers into the context
func NewOutgoingContext(ctx context.Context, header http.Header) context.Context {
storyicon marked this conversation as resolved.
Show resolved Hide resolved
return context.WithValue(ctx, mdOutgoingKey{}, rawMd{headers: header})
}

// HeadersFromOutgoingContext is used to extract http headers from the context
func HeadersFromOutgoingContext(ctx context.Context) (http.Header, bool) {
storyicon marked this conversation as resolved.
Show resolved Hide resolved
value := ctx.Value(mdOutgoingKey{})
if value == nil {
return nil, false
}
headers := value.(rawMd).headers
if headers == nil {
return nil, false
}
return headers, true
}

// mergeHeadersFromOutgoingContext is used to extract http headers from the context and inject it into the provided headers
func mergeHeadersFromOutgoingContext(ctx context.Context, headers http.Header) {
storyicon marked this conversation as resolved.
Show resolved Hide resolved
if kvs, ok := HeadersFromOutgoingContext(ctx); ok {
for key, values := range kvs {
headers.Del(key)
for _, val := range values {
headers.Add(key, val)
}
}
}
}