Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

oauth2RoundTripper: Avoid race condition and readability changes. #634

Merged
merged 1 commit into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 103 additions & 100 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,12 +263,12 @@ func (o *OAuth2) UnmarshalJSON(data []byte) error {
}

// SetDirectory joins any relative file paths with dir.
func (a *OAuth2) SetDirectory(dir string) {
if a == nil {
func (o *OAuth2) SetDirectory(dir string) {
if o == nil {
return
}
a.ClientSecretFile = JoinDir(dir, a.ClientSecretFile)
a.TLSConfig.SetDirectory(dir)
o.ClientSecretFile = JoinDir(dir, o.ClientSecretFile)
o.TLSConfig.SetDirectory(dir)
}

// LoadHTTPConfig parses the YAML input s into a HTTPClientConfig.
Expand Down Expand Up @@ -563,7 +563,7 @@ func NewRoundTripperFromConfig(cfg HTTPClientConfig, name string, optFuncs ...HT
return NewRoundTripperFromConfigWithContext(context.Background(), cfg, name, optFuncs...)
}

// NewRoundTripperFromConfig returns a new HTTP RoundTripper configured for the
// NewRoundTripperFromConfigWithContext returns a new HTTP RoundTripper configured for the
// given config.HTTPClientConfig and config.HTTPClientOption.
// The name is used as go-conntrack metric label.
func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientConfig, name string, optFuncs ...HTTPClientOption) (http.RoundTripper, error) {
Expand Down Expand Up @@ -647,7 +647,7 @@ func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientCon
}

if cfg.OAuth2 != nil {
clientSecret, err := toSecret(opts.secretManager, Secret(cfg.OAuth2.ClientSecret), cfg.OAuth2.ClientSecretFile, cfg.OAuth2.ClientSecretRef)
clientSecret, err := toSecret(opts.secretManager, cfg.OAuth2.ClientSecret, cfg.OAuth2.ClientSecretFile, cfg.OAuth2.ClientSecretRef)
if err != nil {
return nil, fmt.Errorf("unable to use client secret: %w", err)
}
Expand Down Expand Up @@ -702,7 +702,7 @@ type inlineSecret struct {
text string
}

func (s *inlineSecret) fetch(ctx context.Context) (string, error) {
func (s *inlineSecret) fetch(context.Context) (string, error) {
return s.text, nil
}

Expand Down Expand Up @@ -737,7 +737,7 @@ func (s *fileSecret) immutable() bool {
// refSecret fetches a single secret from a secret manager.
type refSecret struct {
ref string
manager SecretManager
manager SecretManager // manager is expected to be not nil.
}

func (s *refSecret) fetch(ctx context.Context) (string, error) {
Expand Down Expand Up @@ -791,20 +791,22 @@ func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials se
}

func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
if len(req.Header.Get("Authorization")) == 0 {
var authCredentials string
if rt.authCredentials != nil {
var err error
authCredentials, err = rt.authCredentials.fetch(req.Context())
if err != nil {
return nil, fmt.Errorf("unable to read authorization credentials: %w", err)
}
}
if len(req.Header.Get("Authorization")) != 0 {
return rt.rt.RoundTrip(req)
}

req = cloneRequest(req)
req.Header.Set("Authorization", fmt.Sprintf("%s %s", rt.authType, authCredentials))
var authCredentials string
if rt.authCredentials != nil {
var err error
authCredentials, err = rt.authCredentials.fetch(req.Context())
if err != nil {
return nil, fmt.Errorf("unable to read authorization credentials: %w", err)
}
}

req = cloneRequest(req)
req.Header.Set("Authorization", fmt.Sprintf("%s %s", rt.authType, authCredentials))

return rt.rt.RoundTrip(req)
}

Expand Down Expand Up @@ -858,117 +860,118 @@ func (rt *basicAuthRoundTripper) CloseIdleConnections() {
}

type oauth2RoundTripper struct {
mtx sync.RWMutex
lastRT *oauth2.Transport
lastSecret string

// Required for interaction with Oauth2 server.
config *OAuth2
rt http.RoundTripper
next http.RoundTripper
clientSecret secret
lastSecret string
mtx sync.RWMutex
opts *httpClientOptions
client *http.Client
}

func NewOAuth2RoundTripper(clientSecret secret, config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper {
if clientSecret == nil {
clientSecret = &inlineSecret{text: ""}
}

return &oauth2RoundTripper{
config: config,
next: next,
config: config,
// A correct tokenSource will be added later on.
lastRT: &oauth2.Transport{Base: next},
opts: opts,
clientSecret: clientSecret,
}
}

func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
var (
secret string
changed bool
)
func (rt *oauth2RoundTripper) newOauth2TokenSource(req *http.Request, secret string) (client *http.Client, source oauth2.TokenSource, err error) {
tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig, WithSecretManager(rt.opts.secretManager))
if err != nil {
return nil, nil, err
}

// Fetch the secret if it's our first run or always if the secret can change.
if rt.rt == nil || (rt.clientSecret != nil && !rt.clientSecret.immutable()) {
if rt.clientSecret != nil {
var err error
secret, err = rt.clientSecret.fetch(req.Context())
if err != nil {
return nil, fmt.Errorf("unable to read oauth2 client secret: %w", err)
}
tlsTransport := func(tlsConfig *tls.Config) (http.RoundTripper, error) {
return &http.Transport{
TLSClientConfig: tlsConfig,
Proxy: rt.config.ProxyConfig.Proxy(),
ProxyConnectHeader: rt.config.ProxyConfig.GetProxyConnectHeader(),
DisableKeepAlives: !rt.opts.keepAlivesEnabled,
MaxIdleConns: 20,
MaxIdleConnsPerHost: 1, // see https://github.com/golang/go/issues/13801
IdleConnTimeout: 10 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}, nil
}

if !rt.clientSecret.immutable() {
rt.mtx.RLock()
changed = secret != rt.lastSecret
rt.mtx.RUnlock()
}
var t http.RoundTripper
tlsSettings, err := rt.config.TLSConfig.roundTripperSettings(rt.opts.secretManager)
if err != nil {
return nil, nil, err
}
if tlsSettings.CA == nil || tlsSettings.CA.immutable() {
t, _ = tlsTransport(tlsConfig)
} else {
t, err = NewTLSRoundTripperWithContext(req.Context(), tlsConfig, tlsSettings, tlsTransport)
if err != nil {
return nil, nil, err
}
}

if rt.rt == nil {
changed = true
}
if ua := req.UserAgent(); ua != "" {
t = NewUserAgentRoundTripper(ua, t)
}

if changed {
config := &clientcredentials.Config{
ClientID: rt.config.ClientID,
ClientSecret: secret,
Scopes: rt.config.Scopes,
TokenURL: rt.config.TokenURL,
EndpointParams: mapToValues(rt.config.EndpointParams),
}
config := &clientcredentials.Config{
ClientID: rt.config.ClientID,
ClientSecret: secret,
Scopes: rt.config.Scopes,
TokenURL: rt.config.TokenURL,
EndpointParams: mapToValues(rt.config.EndpointParams),
}
client = &http.Client{Transport: t}
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client)
return client, config.TokenSource(ctx), nil
}

tlsConfig, err := NewTLSConfig(&rt.config.TLSConfig, WithSecretManager(rt.opts.secretManager))
if err != nil {
return nil, err
}
func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
var (
secret string
needsInit bool
)

tlsTransport := func(tlsConfig *tls.Config) (http.RoundTripper, error) {
return &http.Transport{
TLSClientConfig: tlsConfig,
Proxy: rt.config.ProxyConfig.Proxy(),
ProxyConnectHeader: rt.config.ProxyConfig.GetProxyConnectHeader(),
DisableKeepAlives: !rt.opts.keepAlivesEnabled,
MaxIdleConns: 20,
MaxIdleConnsPerHost: 1, // see https://github.com/golang/go/issues/13801
IdleConnTimeout: 10 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}, nil
}
rt.mtx.RLock()
secret = rt.lastSecret
needsInit = rt.lastRT.Source == nil
rt.mtx.RUnlock()

var t http.RoundTripper
tlsSettings, err := rt.config.TLSConfig.roundTripperSettings(rt.opts.secretManager)
// Fetch the secret if it's our first run or always if the secret can change.
if !rt.clientSecret.immutable() || needsInit {
newSecret, err := rt.clientSecret.fetch(req.Context())
if err != nil {
return nil, err
return nil, fmt.Errorf("unable to read oauth2 client secret: %w", err)
}
if tlsSettings.CA == nil || tlsSettings.CA.immutable() {
t, _ = tlsTransport(tlsConfig)
} else {
t, err = NewTLSRoundTripperWithContext(req.Context(), tlsConfig, tlsSettings, tlsTransport)
if newSecret != secret || needsInit {
// Secret changed or it's a first run. Rebuilt oauth2 setup.
client, source, err := rt.newOauth2TokenSource(req, newSecret)
if err != nil {
return nil, err
}
}

if ua := req.UserAgent(); ua != "" {
t = NewUserAgentRoundTripper(ua, t)
}

client := &http.Client{Transport: t}
ctx := context.WithValue(context.Background(), oauth2.HTTPClient, client)
tokenSource := config.TokenSource(ctx)

rt.mtx.Lock()
rt.lastSecret = secret
rt.rt = &oauth2.Transport{
Base: rt.next,
Source: tokenSource,
}
if rt.client != nil {
rt.client.CloseIdleConnections()
rt.mtx.Lock()
rt.lastSecret = secret
rt.lastRT.Source = source
if rt.client != nil {
rt.client.CloseIdleConnections()
}
rt.client = client
rt.mtx.Unlock()
}
rt.client = client
rt.mtx.Unlock()
}

rt.mtx.RLock()
currentRT := rt.rt
currentRT := rt.lastRT
rt.mtx.RUnlock()
return currentRT.RoundTrip(req)
}
Expand All @@ -977,7 +980,7 @@ func (rt *oauth2RoundTripper) CloseIdleConnections() {
if rt.client != nil {
rt.client.CloseIdleConnections()
}
if ci, ok := rt.next.(closeIdler); ok {
if ci, ok := rt.lastRT.Base.(closeIdler); ok {
ci.CloseIdleConnections()
}
}
Expand Down Expand Up @@ -1019,7 +1022,7 @@ func NewTLSConfig(cfg *TLSConfig, optFuncs ...TLSConfigOption) (*tls.Config, err
return NewTLSConfigWithContext(context.Background(), cfg, optFuncs...)
}

// NewTLSConfig creates a new tls.Config from the given TLSConfig.
// NewTLSConfigWithContext creates a new tls.Config from the given TLSConfig.
func NewTLSConfigWithContext(ctx context.Context, cfg *TLSConfig, optFuncs ...TLSConfigOption) (*tls.Config, error) {
opts := tlsConfigOptions{}
for _, opt := range optFuncs {
Expand Down
72 changes: 37 additions & 35 deletions config/http_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -507,46 +507,48 @@ func TestNewClientFromConfig(t *testing.T) {
}

for _, validConfig := range newClientValidConfig {
testServer, err := newTestServer(validConfig.handler)
if err != nil {
t.Fatal(err.Error())
}
defer testServer.Close()
t.Run("", func(t *testing.T) {
bwplotka marked this conversation as resolved.
Show resolved Hide resolved
testServer, err := newTestServer(validConfig.handler)
if err != nil {
t.Fatal(err.Error())
}
defer testServer.Close()

if validConfig.clientConfig.OAuth2 != nil {
// We don't have access to the test server's URL when configuring the test cases,
// so it has to be specified here.
validConfig.clientConfig.OAuth2.TokenURL = testServer.URL + "/token"
}
if validConfig.clientConfig.OAuth2 != nil {
// We don't have access to the test server's URL when configuring the test cases,
// so it has to be specified here.
validConfig.clientConfig.OAuth2.TokenURL = testServer.URL + "/token"
}

err = validConfig.clientConfig.Validate()
if err != nil {
t.Fatal(err.Error())
}
client, err := NewClientFromConfig(validConfig.clientConfig, "test")
if err != nil {
t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig)
continue
}
err = validConfig.clientConfig.Validate()
if err != nil {
t.Fatal(err.Error())
}
client, err := NewClientFromConfig(validConfig.clientConfig, "test")
if err != nil {
t.Errorf("Can't create a client from this config: %+v", validConfig.clientConfig)
return
}

response, err := client.Get(testServer.URL)
if err != nil {
t.Errorf("Can't connect to the test server using this config: %+v: %v", validConfig.clientConfig, err)
continue
}
response, err := client.Get(testServer.URL)
if err != nil {
t.Errorf("Can't connect to the test server using this config: %+v: %v", validConfig.clientConfig, err)
return
}

message, err := io.ReadAll(response.Body)
response.Body.Close()
if err != nil {
t.Errorf("Can't read the server response body using this config: %+v", validConfig.clientConfig)
continue
}
message, err := io.ReadAll(response.Body)
response.Body.Close()
if err != nil {
t.Errorf("Can't read the server response body using this config: %+v", validConfig.clientConfig)
return
}

trimMessage := strings.TrimSpace(string(message))
if ExpectedMessage != trimMessage {
t.Errorf("The expected message (%s) differs from the obtained message (%s) using this config: %+v",
ExpectedMessage, trimMessage, validConfig.clientConfig)
}
trimMessage := strings.TrimSpace(string(message))
if ExpectedMessage != trimMessage {
t.Errorf("The expected message (%s) differs from the obtained message (%s) using this config: %+v",
ExpectedMessage, trimMessage, validConfig.clientConfig)
}
})
}
}

Expand Down