diff --git a/config/http_config.go b/config/http_config.go index 75e38bb8..07241327 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -679,7 +679,7 @@ func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientCon if err != nil { return nil, err } - if tlsSettings.CA == nil || tlsSettings.CA.immutable() { + if tlsSettings.CA == nil || tlsSettings.CA.Immutable() { // No need for a RoundTripper that reloads the CA file automatically. return newRT(tlsConfig) } @@ -692,25 +692,29 @@ type SecretManager interface { Fetch(ctx context.Context, secretRef string) (string, error) } -type secret interface { - fetch(ctx context.Context) (string, error) - description() string - immutable() bool +type SecretReader interface { + Fetch(ctx context.Context) (string, error) + Description() string + Immutable() bool } -type inlineSecret struct { +type InlineSecret struct { text string } -func (s *inlineSecret) fetch(context.Context) (string, error) { +func NewInlineSecret(text string) *InlineSecret { + return &InlineSecret{text: text} +} + +func (s *InlineSecret) Fetch(context.Context) (string, error) { return s.text, nil } -func (s *inlineSecret) description() string { +func (s *InlineSecret) Description() string { return "inline" } -func (s *inlineSecret) immutable() bool { +func (s *InlineSecret) Immutable() bool { return true } @@ -718,7 +722,7 @@ type fileSecret struct { file string } -func (s *fileSecret) fetch(ctx context.Context) (string, error) { +func (s *fileSecret) Fetch(ctx context.Context) (string, error) { fileBytes, err := os.ReadFile(s.file) if err != nil { return "", fmt.Errorf("unable to read file %s: %w", s.file, err) @@ -726,39 +730,37 @@ func (s *fileSecret) fetch(ctx context.Context) (string, error) { return strings.TrimSpace(string(fileBytes)), nil } -func (s *fileSecret) description() string { +func (s *fileSecret) Description() string { return fmt.Sprintf("file %s", s.file) } -func (s *fileSecret) immutable() bool { +func (s *fileSecret) Immutable() bool { return false } -// refSecret fetches a single secret from a secret manager. +// refSecret fetches a single secret from a SecretManager. type refSecret struct { ref string manager SecretManager // manager is expected to be not nil. } -func (s *refSecret) fetch(ctx context.Context) (string, error) { +func (s *refSecret) Fetch(ctx context.Context) (string, error) { return s.manager.Fetch(ctx, s.ref) } -func (s *refSecret) description() string { +func (s *refSecret) Description() string { return fmt.Sprintf("ref %s", s.ref) } -func (s *refSecret) immutable() bool { +func (s *refSecret) Immutable() bool { return false } -// toSecret returns a secret from one of the given sources, assuming exactly +// toSecret returns a SecretReader from one of the given sources, assuming exactly // one or none of the sources are provided. -func toSecret(secretManager SecretManager, text Secret, file, ref string) (secret, error) { +func toSecret(secretManager SecretManager, text Secret, file, ref string) (SecretReader, error) { if text != "" { - return &inlineSecret{ - text: string(text), - }, nil + return NewInlineSecret(string(text)), nil } if file != "" { return &fileSecret{ @@ -779,14 +781,14 @@ func toSecret(secretManager SecretManager, text Secret, file, ref string) (secre type authorizationCredentialsRoundTripper struct { authType string - authCredentials secret + authCredentials SecretReader rt http.RoundTripper } // NewAuthorizationCredentialsRoundTripper adds the authorization credentials -// read from the provided secret to a request unless the authorization header +// read from the provided SecretReader to a request unless the authorization header // has already been set. -func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials secret, rt http.RoundTripper) http.RoundTripper { +func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials SecretReader, rt http.RoundTripper) http.RoundTripper { return &authorizationCredentialsRoundTripper{authType, authCredentials, rt} } @@ -798,7 +800,7 @@ func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*h var authCredentials string if rt.authCredentials != nil { var err error - authCredentials, err = rt.authCredentials.fetch(req.Context()) + authCredentials, err = rt.authCredentials.Fetch(req.Context()) if err != nil { return nil, fmt.Errorf("unable to read authorization credentials: %w", err) } @@ -817,14 +819,14 @@ func (rt *authorizationCredentialsRoundTripper) CloseIdleConnections() { } type basicAuthRoundTripper struct { - username secret - password secret + username SecretReader + password SecretReader rt http.RoundTripper } // NewBasicAuthRoundTripper will apply a BASIC auth authorization header to a request unless it has // already been set. -func NewBasicAuthRoundTripper(username secret, password secret, rt http.RoundTripper) http.RoundTripper { +func NewBasicAuthRoundTripper(username SecretReader, password SecretReader, rt http.RoundTripper) http.RoundTripper { return &basicAuthRoundTripper{username, password, rt} } @@ -836,14 +838,14 @@ func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, e var password string if rt.username != nil { var err error - username, err = rt.username.fetch(req.Context()) + username, err = rt.username.Fetch(req.Context()) if err != nil { return nil, fmt.Errorf("unable to read basic auth username: %w", err) } } if rt.password != nil { var err error - password, err = rt.password.fetch(req.Context()) + password, err = rt.password.Fetch(req.Context()) if err != nil { return nil, fmt.Errorf("unable to read basic auth password: %w", err) } @@ -866,14 +868,14 @@ type oauth2RoundTripper struct { // Required for interaction with Oauth2 server. config *OAuth2 - clientSecret secret + clientSecret SecretReader opts *httpClientOptions client *http.Client } -func NewOAuth2RoundTripper(clientSecret secret, config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper { +func NewOAuth2RoundTripper(clientSecret SecretReader, config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper { if clientSecret == nil { - clientSecret = &inlineSecret{text: ""} + clientSecret = NewInlineSecret("") } return &oauth2RoundTripper{ @@ -910,7 +912,7 @@ func (rt *oauth2RoundTripper) newOauth2TokenSource(req *http.Request, secret str if err != nil { return nil, nil, err } - if tlsSettings.CA == nil || tlsSettings.CA.immutable() { + if tlsSettings.CA == nil || tlsSettings.CA.Immutable() { t, _ = tlsTransport(tlsConfig) } else { t, err = NewTLSRoundTripperWithContext(req.Context(), tlsConfig, tlsSettings, tlsTransport) @@ -947,8 +949,8 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro rt.mtx.RUnlock() // Fetch the secret if it's our first run or always if the secret can change. - if !rt.clientSecret.immutable() || needsInit { - newSecret, err := rt.clientSecret.fetch(req.Context()) + if !rt.clientSecret.Immutable() || needsInit { + newSecret, err := rt.clientSecret.Fetch(req.Context()) if err != nil { return nil, fmt.Errorf("unable to read oauth2 client secret: %w", err) } @@ -1052,12 +1054,12 @@ func NewTLSConfigWithContext(ctx context.Context, cfg *TLSConfig, optFuncs ...TL return nil, fmt.Errorf("unable to use CA cert: %w", err) } if caSecret != nil { - ca, err := caSecret.fetch(ctx) + ca, err := caSecret.Fetch(ctx) if err != nil { return nil, fmt.Errorf("unable to read CA cert: %w", err) } if !updateRootCA(tlsConfig, []byte(ca)) { - return nil, fmt.Errorf("unable to use specified CA cert %s", caSecret.description()) + return nil, fmt.Errorf("unable to use specified CA cert %s", caSecret.Description()) } } @@ -1198,7 +1200,7 @@ func (c *TLSConfig) getClientCertificate(ctx context.Context, secretManager Secr return nil, fmt.Errorf("unable to use client cert: %w", err) } if certSecret != nil { - certData, err = certSecret.fetch(ctx) + certData, err = certSecret.Fetch(ctx) if err != nil { return nil, fmt.Errorf("unable to read specified client cert: %w", err) } @@ -1209,7 +1211,7 @@ func (c *TLSConfig) getClientCertificate(ctx context.Context, secretManager Secr return nil, fmt.Errorf("unable to use client key: %w", err) } if keySecret != nil { - keyData, err = keySecret.fetch(ctx) + keyData, err = keySecret.Fetch(ctx) if err != nil { return nil, fmt.Errorf("unable to read specified client key: %w", err) } @@ -1217,7 +1219,7 @@ func (c *TLSConfig) getClientCertificate(ctx context.Context, secretManager Secr cert, err := tls.X509KeyPair([]byte(certData), []byte(keyData)) if err != nil { - return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %w", certSecret.description(), keySecret.description(), err) + return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %w", certSecret.Description(), keySecret.Description(), err) } return &cert, nil @@ -1250,9 +1252,9 @@ type tlsRoundTripper struct { } type TLSRoundTripperSettings struct { - CA secret - Cert secret - Key secret + CA SecretReader + Cert SecretReader + Key SecretReader } func NewTLSRoundTripper( @@ -1292,7 +1294,7 @@ func (t *tlsRoundTripper) getTLSDataWithHash(ctx context.Context) ([]byte, []byt var caBytes, certBytes, keyBytes []byte if t.settings.CA != nil { - ca, err := t.settings.CA.fetch(ctx) + ca, err := t.settings.CA.Fetch(ctx) if err != nil { return nil, nil, nil, nil, fmt.Errorf("unable to read CA cert: %w", err) } @@ -1300,7 +1302,7 @@ func (t *tlsRoundTripper) getTLSDataWithHash(ctx context.Context) ([]byte, []byt } if t.settings.Cert != nil { - cert, err := t.settings.Cert.fetch(ctx) + cert, err := t.settings.Cert.Fetch(ctx) if err != nil { return nil, nil, nil, nil, fmt.Errorf("unable to read client cert: %w", err) } @@ -1308,7 +1310,7 @@ func (t *tlsRoundTripper) getTLSDataWithHash(ctx context.Context) ([]byte, []byt } if t.settings.Key != nil { - key, err := t.settings.Key.fetch(ctx) + key, err := t.settings.Key.Fetch(ctx) if err != nil { return nil, nil, nil, nil, fmt.Errorf("unable to read client key: %w", err) } @@ -1353,7 +1355,7 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { // using GetClientCertificate. tlsConfig := t.tlsConfig.Clone() if !updateRootCA(tlsConfig, caData) { - return nil, fmt.Errorf("unable to use specified CA cert %s", t.settings.CA.description()) + return nil, fmt.Errorf("unable to use specified CA cert %s", t.settings.CA.Description()) } rt, err = t.newRT(tlsConfig) if err != nil { diff --git a/config/http_config_test.go b/config/http_config_test.go index 14e07c22..67b7408d 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -727,7 +727,7 @@ func TestBearerAuthRoundTripper(t *testing.T) { }, nil, nil) // Normal flow. - bearerAuthRoundTripper := NewAuthorizationCredentialsRoundTripper("Bearer", &inlineSecret{text: BearerToken}, fakeRoundTripper) + bearerAuthRoundTripper := NewAuthorizationCredentialsRoundTripper("Bearer", NewInlineSecret(BearerToken), fakeRoundTripper) request, _ := http.NewRequest("GET", "/hitchhiker", nil) request.Header.Set("User-Agent", "Douglas Adams mind") _, err := bearerAuthRoundTripper.RoundTrip(request) @@ -736,7 +736,7 @@ func TestBearerAuthRoundTripper(t *testing.T) { } // Should honor already Authorization header set. - bearerAuthRoundTripperShouldNotModifyExistingAuthorization := NewAuthorizationCredentialsRoundTripper("Bearer", &inlineSecret{text: newBearerToken}, fakeRoundTripper) + bearerAuthRoundTripperShouldNotModifyExistingAuthorization := NewAuthorizationCredentialsRoundTripper("Bearer", NewInlineSecret(newBearerToken), fakeRoundTripper) request, _ = http.NewRequest("GET", "/hitchhiker", nil) request.Header.Set("Authorization", ExpectedBearer) _, err = bearerAuthRoundTripperShouldNotModifyExistingAuthorization.RoundTrip(request) @@ -936,7 +936,7 @@ func TestBasicAuthNoPassword(t *testing.T) { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } - if username, _ := rt.username.fetch(context.Background()); username != "user" { + if username, _ := rt.username.Fetch(context.Background()); username != "user" { t.Errorf("Bad HTTP client username: %s", username) } if rt.password != nil { @@ -962,7 +962,7 @@ func TestBasicAuthNoUsername(t *testing.T) { if rt.username != nil { t.Errorf("Got unexpected username") } - if password, _ := rt.password.fetch(context.Background()); password != "secret" { + if password, _ := rt.password.Fetch(context.Background()); password != "secret" { t.Errorf("Unexpected HTTP client password: %s", password) } } @@ -982,10 +982,10 @@ func TestBasicAuthPasswordFile(t *testing.T) { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } - if username, _ := rt.username.fetch(context.Background()); username != "user" { + if username, _ := rt.username.Fetch(context.Background()); username != "user" { t.Errorf("Bad HTTP client username: %s", username) } - if password, _ := rt.password.fetch(context.Background()); password != "foobar" { + if password, _ := rt.password.Fetch(context.Background()); password != "foobar" { t.Errorf("Bad HTTP client password: %s", password) } } @@ -1023,10 +1023,10 @@ func TestBasicAuthSecretManager(t *testing.T) { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } - if username, _ := rt.username.fetch(context.Background()); username != "user" { + if username, _ := rt.username.Fetch(context.Background()); username != "user" { t.Errorf("Bad HTTP client username: %s", username) } - if password, _ := rt.password.fetch(context.Background()); password != "foobar" { + if password, _ := rt.password.Fetch(context.Background()); password != "foobar" { t.Errorf("Bad HTTP client password: %s", password) } } @@ -1052,10 +1052,10 @@ func TestBasicAuthSecretManagerNotFound(t *testing.T) { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } - if _, err := rt.username.fetch(context.Background()); !strings.Contains(err.Error(), "unknown secret admin") { + if _, err := rt.username.Fetch(context.Background()); !strings.Contains(err.Error(), "unknown secret admin") { t.Errorf("Unexpected error message: %s", err) } - if _, err := rt.password.fetch(context.Background()); !strings.Contains(err.Error(), "unknown secret pass") { + if _, err := rt.password.Fetch(context.Background()); !strings.Contains(err.Error(), "unknown secret pass") { t.Errorf("Unexpected error message: %s", err) } } @@ -1075,10 +1075,10 @@ func TestBasicUsernameFile(t *testing.T) { t.Fatalf("Error casting to basic auth transport, %v", client.Transport) } - if username, _ := rt.username.fetch(context.Background()); username != "testuser" { + if username, _ := rt.username.Fetch(context.Background()); username != "testuser" { t.Errorf("Bad HTTP client username: %s", username) } - if password, _ := rt.password.fetch(context.Background()); password != "foobar" { + if password, _ := rt.password.Fetch(context.Background()); password != "foobar" { t.Errorf("Bad HTTP client passwordFile: %s", password) } } @@ -1629,7 +1629,7 @@ endpoint_params: t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig) } - rt := NewOAuth2RoundTripper(&inlineSecret{text: string(expectedConfig.ClientSecret)}, &expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) + rt := NewOAuth2RoundTripper(NewInlineSecret(string(expectedConfig.ClientSecret)), &expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) client := http.Client{ Transport: rt, @@ -1799,7 +1799,7 @@ endpoint_params: t.Fatalf("Got unmarshalled config %v, expected %v", unmarshalledConfig, expectedConfig) } - rt := NewOAuth2RoundTripper(&inlineSecret{text: string(expectedConfig.ClientSecret)}, &expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) + rt := NewOAuth2RoundTripper(NewInlineSecret(string(expectedConfig.ClientSecret)), &expectedConfig, http.DefaultTransport, &defaultHTTPClientOptions) client := http.Client{ Transport: rt,