From 6b9921f9eba2cd74f2caca0d713bb0a6eb7ef1b9 Mon Sep 17 00:00:00 2001 From: Bartlomiej Plotka Date: Thu, 16 May 2024 15:40:54 +0200 Subject: [PATCH] Refactored oauth2RoundTripper.RoundTrip (#634) * Avoid race condidtion on rt.rt == nil check * Trying to improve readability (less ifs) * Some comment fixes Signed-off-by: bwplotka --- config/http_config.go | 203 +++++++++++++++++++------------------ config/http_config_test.go | 72 ++++++------- 2 files changed, 140 insertions(+), 135 deletions(-) diff --git a/config/http_config.go b/config/http_config.go index 5e9d6507..75e38bb8 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -263,12 +263,12 @@ func (o *OAuth2) UnmarshalJSON(data []byte) error { } // SetDirectory joins any relative file paths with dir. -func (a *OAuth2) SetDirectory(dir string) { - if a == nil { +func (o *OAuth2) SetDirectory(dir string) { + if o == nil { return } - a.ClientSecretFile = JoinDir(dir, a.ClientSecretFile) - a.TLSConfig.SetDirectory(dir) + o.ClientSecretFile = JoinDir(dir, o.ClientSecretFile) + o.TLSConfig.SetDirectory(dir) } // LoadHTTPConfig parses the YAML input s into a HTTPClientConfig. @@ -563,7 +563,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT return NewRoundTripperFromConfigWithContext(context.Background(), cfg, name, optFuncs...) } -// NewRoundTripperFromConfig returns a new HTTP RoundTripper configured for the +// NewRoundTripperFromConfigWithContext returns a new HTTP RoundTripper configured for the // given config.HTTPClientConfig and config.HTTPClientOption. // The name is used as go-conntrack metric label. func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (http.RoundTripper, error) { @@ -647,7 +647,7 @@ func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientCon } if cfg.OAuth2 != nil { - clientSecret, err := toSecret(opts.secretManager, Secret(cfg.OAuth2.ClientSecret), cfg.OAuth2.ClientSecretFile, cfg.OAuth2.ClientSecretRef) + clientSecret, err := toSecret(opts.secretManager, cfg.OAuth2.ClientSecret, cfg.OAuth2.ClientSecretFile, cfg.OAuth2.ClientSecretRef) if err != nil { return nil, fmt.Errorf("unable to use client secret: %w", err) } @@ -702,7 +702,7 @@ type inlineSecret struct { text string } -func (s *inlineSecret) fetch(ctx context.Context) (string, error) { +func (s *inlineSecret) fetch(context.Context) (string, error) { return s.text, nil } @@ -737,7 +737,7 @@ func (s *fileSecret) immutable() bool { // refSecret fetches a single secret from a secret manager. type refSecret struct { ref string - manager SecretManager + manager SecretManager // manager is expected to be not nil. } func (s *refSecret) fetch(ctx context.Context) (string, error) { @@ -791,20 +791,22 @@ func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials se } func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if len(req.Header.Get("Authorization")) == 0 { - var authCredentials string - if rt.authCredentials != nil { - var err error - authCredentials, err = rt.authCredentials.fetch(req.Context()) - if err != nil { - return nil, fmt.Errorf("unable to read authorization credentials: %w", err) - } - } + if len(req.Header.Get("Authorization")) != 0 { + return rt.rt.RoundTrip(req) + } - req = cloneRequest(req) - req.Header.Set("Authorization", fmt.Sprintf("%s %s", rt.authType, authCredentials)) + var authCredentials string + if rt.authCredentials != nil { + var err error + authCredentials, err = rt.authCredentials.fetch(req.Context()) + if err != nil { + return nil, fmt.Errorf("unable to read authorization credentials: %w", err) + } } + req = cloneRequest(req) + req.Header.Set("Authorization", fmt.Sprintf("%s %s", rt.authType, authCredentials)) + return rt.rt.RoundTrip(req) } @@ -858,117 +860,118 @@ func (rt *basicAuthRoundTripper) CloseIdleConnections() { } type oauth2RoundTripper struct { + mtx sync.RWMutex + lastRT *oauth2.Transport + lastSecret string + + // Required for interaction with Oauth2 server. config *OAuth2 - rt http.RoundTripper - next http.RoundTripper clientSecret secret - lastSecret string - mtx sync.RWMutex opts *httpClientOptions client *http.Client } func NewOAuth2RoundTripper(clientSecret secret, config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper { + if clientSecret == nil { + clientSecret = &inlineSecret{text: ""} + } + return &oauth2RoundTripper{ - config: config, - next: next, + config: config, + // A correct tokenSource will be added later on. + lastRT: &oauth2.Transport{Base: next}, opts: opts, clientSecret: clientSecret, } } -func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - var ( - secret string - changed bool - ) +func (rt *oauth2RoundTripper) newOauth2TokenSource(req *http.Request, secret string) (client *http.Client, source oauth2.TokenSource, err error) { + tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig, WithSecretManager(rt.opts.secretManager)) + if err != nil { + return nil, nil, err + } - // Fetch the secret if it's our first run or always if the secret can change. - if rt.rt == nil || (rt.clientSecret != nil && !rt.clientSecret.immutable()) { - if rt.clientSecret != nil { - var err error - secret, err = rt.clientSecret.fetch(req.Context()) - if err != nil { - return nil, fmt.Errorf("unable to read oauth2 client secret: %w", err) - } + tlsTransport := func(tlsConfig *tls.Config) (http.RoundTripper, error) { + return &http.Transport{ + TLSClientConfig: tlsConfig, + Proxy: rt.config.ProxyConfig.Proxy(), + ProxyConnectHeader: rt.config.ProxyConfig.GetProxyConnectHeader(), + DisableKeepAlives: !rt.opts.keepAlivesEnabled, + MaxIdleConns: 20, + MaxIdleConnsPerHost: 1, // see https://github.com/golang/go/issues/13801 + IdleConnTimeout: 10 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + }, nil + } - if !rt.clientSecret.immutable() { - rt.mtx.RLock() - changed = secret != rt.lastSecret - rt.mtx.RUnlock() - } + var t http.RoundTripper + tlsSettings, err := rt.config.TLSConfig.roundTripperSettings(rt.opts.secretManager) + if err != nil { + return nil, nil, err + } + if tlsSettings.CA == nil || tlsSettings.CA.immutable() { + t, _ = tlsTransport(tlsConfig) + } else { + t, err = NewTLSRoundTripperWithContext(req.Context(), tlsConfig, tlsSettings, tlsTransport) + if err != nil { + return nil, nil, err } + } - if rt.rt == nil { - changed = true - } + if ua := req.UserAgent(); ua != "" { + t = NewUserAgentRoundTripper(ua, t) } - if changed { - config := &clientcredentials.Config{ - ClientID: rt.config.ClientID, - ClientSecret: secret, - Scopes: rt.config.Scopes, - TokenURL: rt.config.TokenURL, - EndpointParams: mapToValues(rt.config.EndpointParams), - } + config := &clientcredentials.Config{ + ClientID: rt.config.ClientID, + ClientSecret: secret, + Scopes: rt.config.Scopes, + TokenURL: rt.config.TokenURL, + EndpointParams: mapToValues(rt.config.EndpointParams), + } + client = &http.Client{Transport: t} + ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client) + return client, config.TokenSource(ctx), nil +} - tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig, WithSecretManager(rt.opts.secretManager)) - if err != nil { - return nil, err - } +func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + var ( + secret string + needsInit bool + ) - tlsTransport := func(tlsConfig *tls.Config) (http.RoundTripper, error) { - return &http.Transport{ - TLSClientConfig: tlsConfig, - Proxy: rt.config.ProxyConfig.Proxy(), - ProxyConnectHeader: rt.config.ProxyConfig.GetProxyConnectHeader(), - DisableKeepAlives: !rt.opts.keepAlivesEnabled, - MaxIdleConns: 20, - MaxIdleConnsPerHost: 1, // see https://github.com/golang/go/issues/13801 - IdleConnTimeout: 10 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - }, nil - } + rt.mtx.RLock() + secret = rt.lastSecret + needsInit = rt.lastRT.Source == nil + rt.mtx.RUnlock() - var t http.RoundTripper - tlsSettings, err := rt.config.TLSConfig.roundTripperSettings(rt.opts.secretManager) + // Fetch the secret if it's our first run or always if the secret can change. + if !rt.clientSecret.immutable() || needsInit { + newSecret, err := rt.clientSecret.fetch(req.Context()) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to read oauth2 client secret: %w", err) } - if tlsSettings.CA == nil || tlsSettings.CA.immutable() { - t, _ = tlsTransport(tlsConfig) - } else { - t, err = NewTLSRoundTripperWithContext(req.Context(), tlsConfig, tlsSettings, tlsTransport) + if newSecret != secret || needsInit { + // Secret changed or it's a first run. Rebuilt oauth2 setup. + client, source, err := rt.newOauth2TokenSource(req, newSecret) if err != nil { return nil, err } - } - - if ua := req.UserAgent(); ua != "" { - t = NewUserAgentRoundTripper(ua, t) - } - - client := &http.Client{Transport: t} - ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client) - tokenSource := config.TokenSource(ctx) - rt.mtx.Lock() - rt.lastSecret = secret - rt.rt = &oauth2.Transport{ - Base: rt.next, - Source: tokenSource, - } - if rt.client != nil { - rt.client.CloseIdleConnections() + rt.mtx.Lock() + rt.lastSecret = secret + rt.lastRT.Source = source + if rt.client != nil { + rt.client.CloseIdleConnections() + } + rt.client = client + rt.mtx.Unlock() } - rt.client = client - rt.mtx.Unlock() } rt.mtx.RLock() - currentRT := rt.rt + currentRT := rt.lastRT rt.mtx.RUnlock() return currentRT.RoundTrip(req) } @@ -977,7 +980,7 @@ func (rt *oauth2RoundTripper) CloseIdleConnections() { if rt.client != nil { rt.client.CloseIdleConnections() } - if ci, ok := rt.next.(closeIdler); ok { + if ci, ok := rt.lastRT.Base.(closeIdler); ok { ci.CloseIdleConnections() } } @@ -1019,7 +1022,7 @@ func NewTLSConfig(cfg *TLSConfig, optFuncs ...TLSConfigOption) (*tls.Config, err return NewTLSConfigWithContext(context.Background(), cfg, optFuncs...) } -// NewTLSConfig creates a new tls.Config from the given TLSConfig. +// NewTLSConfigWithContext creates a new tls.Config from the given TLSConfig. func NewTLSConfigWithContext(ctx context.Context, cfg *TLSConfig, optFuncs ...TLSConfigOption) (*tls.Config, error) { opts := tlsConfigOptions{} for _, opt := range optFuncs { diff --git a/config/http_config_test.go b/config/http_config_test.go index 5b03a0b5..14e07c22 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -507,46 +507,48 @@ func TestNewClientFromConfig(t *testing.T) { } for _, validConfig := range newClientValidConfig { - testServer, err := newTestServer(validConfig.handler) - if err != nil { - t.Fatal(err.Error()) - } - defer testServer.Close() + t.Run("", func(t *testing.T) { + testServer, err := newTestServer(validConfig.handler) + if err != nil { + t.Fatal(err.Error()) + } + defer testServer.Close() - if validConfig.clientConfig.OAuth2 != nil { - // We don't have access to the test server's URL when configuring the test cases, - // so it has to be specified here. - validConfig.clientConfig.OAuth2.TokenURL = testServer.URL + "/token" - } + if validConfig.clientConfig.OAuth2 != nil { + // We don't have access to the test server's URL when configuring the test cases, + // so it has to be specified here. + validConfig.clientConfig.OAuth2.TokenURL = testServer.URL + "/token" + } - err = validConfig.clientConfig.Validate() - if err != nil { - t.Fatal(err.Error()) - } - client, err := NewClientFromConfig(validConfig.clientConfig, "test") - if err != nil { - t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig) - continue - } + err = validConfig.clientConfig.Validate() + if err != nil { + t.Fatal(err.Error()) + } + client, err := NewClientFromConfig(validConfig.clientConfig, "test") + if err != nil { + t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig) + return + } - response, err := client.Get(testServer.URL) - if err != nil { - t.Errorf("Can't connect to the test server using this config: %+v: %v", validConfig.clientConfig, err) - continue - } + response, err := client.Get(testServer.URL) + if err != nil { + t.Errorf("Can't connect to the test server using this config: %+v: %v", validConfig.clientConfig, err) + return + } - message, err := io.ReadAll(response.Body) - response.Body.Close() - if err != nil { - t.Errorf("Can't read the server response body using this config: %+v", validConfig.clientConfig) - continue - } + message, err := io.ReadAll(response.Body) + response.Body.Close() + if err != nil { + t.Errorf("Can't read the server response body using this config: %+v", validConfig.clientConfig) + return + } - trimMessage := strings.TrimSpace(string(message)) - if ExpectedMessage != trimMessage { - t.Errorf("The expected message (%s) differs from the obtained message (%s) using this config: %+v", - ExpectedMessage, trimMessage, validConfig.clientConfig) - } + trimMessage := strings.TrimSpace(string(message)) + if ExpectedMessage != trimMessage { + t.Errorf("The expected message (%s) differs from the obtained message (%s) using this config: %+v", + ExpectedMessage, trimMessage, validConfig.clientConfig) + } + }) } }