Skip to content

Commit

Permalink
Use WithUserAgent
Browse files Browse the repository at this point in the history
Signed-off-by: Julien Pivotto <roidelapluie@o11y.eu>
  • Loading branch information
roidelapluie committed Jun 30, 2022
1 parent 99a1aca commit 316097c
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 28 deletions.
24 changes: 18 additions & 6 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,6 @@ type OAuth2 struct {
ProxyURL URL `yaml:"proxy_url,omitempty" json:"proxy_url,omitempty"`
// TLSConfig is used to connect to the token URL.
TLSConfig TLSConfig `yaml:"tls_config,omitempty"`
// UserAgent is used to set a custom User-Agent http header while making the oauth request.
UserAgent string `yaml:"user_agent,omitempty" json:"user_agent,omitempty"`
}

// SetDirectory joins any relative file paths with dir.
Expand Down Expand Up @@ -374,6 +372,7 @@ type httpClientOptions struct {
keepAlivesEnabled bool
http2Enabled bool
idleConnTimeout time.Duration
userAgent string
}

// HTTPClientOption defines an option that can be applied to the HTTP client.
Expand Down Expand Up @@ -407,6 +406,13 @@ func WithIdleConnTimeout(timeout time.Duration) HTTPClientOption {
}
}

// WithIdleConnTimeout allows setting the user agent.
func WithUserAgent(ua string) HTTPClientOption {
return func(opts *httpClientOptions) {
opts.userAgent = ua
}
}

// NewClient returns a http.Client using the specified http.RoundTripper.
func newClient(rt http.RoundTripper) *http.Client {
return &http.Client{Transport: rt}
Expand Down Expand Up @@ -499,8 +505,12 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
rt = NewBasicAuthRoundTripper(cfg.BasicAuth.Username, cfg.BasicAuth.Password, cfg.BasicAuth.PasswordFile, rt)
}

if opts.userAgent != "" {
rt = NewUserAgentRoundTripper(opts.userAgent, rt)
}

if cfg.OAuth2 != nil {
rt = NewOAuth2RoundTripper(cfg.OAuth2, rt)
rt = NewOAuth2RoundTripper(cfg.OAuth2, rt, &opts)
}
// Return a new configured RoundTripper.
return rt, nil
Expand Down Expand Up @@ -621,12 +631,14 @@ type oauth2RoundTripper struct {
next http.RoundTripper
secret string
mtx sync.RWMutex
opts *httpClientOptions
}

func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper) http.RoundTripper {
func NewOAuth2RoundTripper(config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper {
return &oauth2RoundTripper{
config: config,
next: next,
opts: opts,
}
}

Expand Down Expand Up @@ -683,8 +695,8 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
}
}

if rt.config.UserAgent != "" {
t = NewUserAgentRoundTripper(rt.config.UserAgent, t)
if rt.opts.userAgent != "" {
t = NewUserAgentRoundTripper(rt.opts.userAgent, t)
}

ctx := context.WithValue(context.Background(), oauth2.HTTPClient, &http.Client{Transport: t})
Expand Down
63 changes: 46 additions & 17 deletions config/http_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1183,12 +1183,6 @@ type oauth2TestServerResponse struct {

func TestOAuth2(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/token" {
if r.Header.Get("User-Agent") != "myuseragent" {
t.Fatalf("Expected User-Agent header in oauth request to be 'myuseragent', got '%s'", r.Header.Get("User-Agent"))
}
}

res, _ := json.Marshal(oauth2TestServerResponse{
AccessToken: "12345",
TokenType: "Bearer",
Expand All @@ -1205,7 +1199,6 @@ scopes:
- A
- B
token_url: %s/token
user_agent: myuseragent
endpoint_params:
hi: hello
`, ts.URL)
Expand All @@ -1215,7 +1208,6 @@ endpoint_params:
Scopes: []string{"A", "B"},
EndpointParams: map[string]string{"hi": "hello"},
TokenURL: fmt.Sprintf("%s/token", ts.URL),
UserAgent: "myuseragent",
}

var unmarshalledConfig OAuth2
Expand All @@ -1227,7 +1219,7 @@ endpoint_params:
t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig)
}

rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions)

client := http.Client{
Transport: rt,
Expand All @@ -1240,6 +1232,50 @@ endpoint_params:
}
}

func TestOAuth2UserAgent(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/" {
if r.Header.Get("User-Agent") != "myuseragent" {
t.Fatalf("Expected User-Agent header in oauth request to be 'myuseragent', got '%s'", r.Header.Get("User-Agent"))
}
}

res, _ := json.Marshal(oauth2TestServerResponse{
AccessToken: "12345",
TokenType: "Bearer",
})
w.Header().Add("Content-Type", "application/json")
_, _ = w.Write(res)
}))
defer ts.Close()

config := &OAuth2{
ClientID: "1",
ClientSecret: "2",
Scopes: []string{"A", "B"},
EndpointParams: map[string]string{"hi": "hello"},
TokenURL: fmt.Sprintf("%s/token", ts.URL),
}

opts := defaultHTTPClientOptions
WithUserAgent("myuseragent")(&opts)

rt := NewOAuth2RoundTripper(config, http.DefaultTransport, &opts)

client := http.Client{
Transport: rt,
}
resp, err := client.Get(ts.URL)
if err != nil {
t.Fatal(err)
}

authorization := resp.Request.Header.Get("Authorization")
if authorization != "Bearer 12345" {
t.Fatalf("Expected authorization header to be 'Bearer 12345', got '%s'", authorization)
}
}

func TestOAuth2WithFile(t *testing.T) {
var expectedAuth *string
var previousAuth string
Expand Down Expand Up @@ -1302,7 +1338,7 @@ endpoint_params:
t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig)
}

rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport)
rt := NewOAuth2RoundTripper(&expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions)

client := http.Client{
Transport: rt,
Expand Down Expand Up @@ -1496,10 +1532,3 @@ func TestOAuth2Proxy(t *testing.T) {
t.Errorf("Error loading OAuth2 client config: %v", err)
}
}

func TestOAuth2UserAgent(t *testing.T) {
_, _, err := LoadHTTPConfigFile("testdata/http.conf.oauth2-user-agent.good.yml")
if err != nil {
t.Errorf("Error loading OAuth2 client config: %v", err)
}
}
5 changes: 0 additions & 5 deletions config/testdata/http.conf.oauth2-user-agent.good.yml

This file was deleted.

0 comments on commit 316097c

Please sign in to comment.