Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose secret as SecretReader and InlineSecret from config package #650

Merged
merged 2 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 50 additions & 48 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientCon
if err != nil {
return nil, err
}
if tlsSettings.CA == nil || tlsSettings.CA.immutable() {
if tlsSettings.CA == nil || tlsSettings.CA.Immutable() {
// No need for a RoundTripper that reloads the CA file automatically.
return newRT(tlsConfig)
}
Expand All @@ -692,73 +692,75 @@ type SecretManager interface {
Fetch(ctx context.Context, secretRef string) (string, error)
}

type secret interface {
fetch(ctx context.Context) (string, error)
description() string
immutable() bool
type SecretReader interface {
Fetch(ctx context.Context) (string, error)
Description() string
Immutable() bool
}

type inlineSecret struct {
type InlineSecret struct {
text string
}

func (s *inlineSecret) fetch(context.Context) (string, error) {
func NewInlineSecret(text string) *InlineSecret {
return &InlineSecret{text: text}
}

func (s *InlineSecret) Fetch(context.Context) (string, error) {
return s.text, nil
}

func (s *inlineSecret) description() string {
func (s *InlineSecret) Description() string {
return "inline"
}

func (s *inlineSecret) immutable() bool {
func (s *InlineSecret) Immutable() bool {
return true
}

type fileSecret struct {
file string
}

func (s *fileSecret) fetch(ctx context.Context) (string, error) {
func (s *fileSecret) Fetch(ctx context.Context) (string, error) {
fileBytes, err := os.ReadFile(s.file)
if err != nil {
return "", fmt.Errorf("unable to read file %s: %w", s.file, err)
}
return strings.TrimSpace(string(fileBytes)), nil
}

func (s *fileSecret) description() string {
func (s *fileSecret) Description() string {
return fmt.Sprintf("file %s", s.file)
}

func (s *fileSecret) immutable() bool {
func (s *fileSecret) Immutable() bool {
return false
}

// refSecret fetches a single secret from a secret manager.
// refSecret fetches a single secret from a SecretManager.
type refSecret struct {
ref string
manager SecretManager // manager is expected to be not nil.
}

func (s *refSecret) fetch(ctx context.Context) (string, error) {
func (s *refSecret) Fetch(ctx context.Context) (string, error) {
return s.manager.Fetch(ctx, s.ref)
}

func (s *refSecret) description() string {
func (s *refSecret) Description() string {
return fmt.Sprintf("ref %s", s.ref)
}

func (s *refSecret) immutable() bool {
func (s *refSecret) Immutable() bool {
return false
}

// toSecret returns a secret from one of the given sources, assuming exactly
// toSecret returns a SecretReader from one of the given sources, assuming exactly
// one or none of the sources are provided.
func toSecret(secretManager SecretManager, text Secret, file, ref string) (secret, error) {
func toSecret(secretManager SecretManager, text Secret, file, ref string) (SecretReader, error) {
if text != "" {
return &inlineSecret{
text: string(text),
}, nil
return NewInlineSecret(string(text)), nil
}
if file != "" {
return &fileSecret{
Expand All @@ -779,14 +781,14 @@ func toSecret(secretManager SecretManager, text Secret, file, ref string) (secre

type authorizationCredentialsRoundTripper struct {
authType string
authCredentials secret
authCredentials SecretReader
rt http.RoundTripper
}

// NewAuthorizationCredentialsRoundTripper adds the authorization credentials
// read from the provided secret to a request unless the authorization header
// read from the provided SecretReader to a request unless the authorization header
// has already been set.
func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials secret, rt http.RoundTripper) http.RoundTripper {
func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials SecretReader, rt http.RoundTripper) http.RoundTripper {
return &authorizationCredentialsRoundTripper{authType, authCredentials, rt}
}

Expand All @@ -798,7 +800,7 @@ func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*h
var authCredentials string
if rt.authCredentials != nil {
var err error
authCredentials, err = rt.authCredentials.fetch(req.Context())
authCredentials, err = rt.authCredentials.Fetch(req.Context())
if err != nil {
return nil, fmt.Errorf("unable to read authorization credentials: %w", err)
}
Expand All @@ -817,14 +819,14 @@ func (rt *authorizationCredentialsRoundTripper) CloseIdleConnections() {
}

type basicAuthRoundTripper struct {
username secret
password secret
username SecretReader
password SecretReader
rt http.RoundTripper
}

// NewBasicAuthRoundTripper will apply a BASIC auth authorization header to a request unless it has
// already been set.
func NewBasicAuthRoundTripper(username secret, password secret, rt http.RoundTripper) http.RoundTripper {
func NewBasicAuthRoundTripper(username SecretReader, password SecretReader, rt http.RoundTripper) http.RoundTripper {
return &basicAuthRoundTripper{username, password, rt}
}

Expand All @@ -836,14 +838,14 @@ func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
var password string
if rt.username != nil {
var err error
username, err = rt.username.fetch(req.Context())
username, err = rt.username.Fetch(req.Context())
if err != nil {
return nil, fmt.Errorf("unable to read basic auth username: %w", err)
}
}
if rt.password != nil {
var err error
password, err = rt.password.fetch(req.Context())
password, err = rt.password.Fetch(req.Context())
if err != nil {
return nil, fmt.Errorf("unable to read basic auth password: %w", err)
}
Expand All @@ -866,14 +868,14 @@ type oauth2RoundTripper struct {

// Required for interaction with Oauth2 server.
config *OAuth2
clientSecret secret
clientSecret SecretReader
opts *httpClientOptions
client *http.Client
}

func NewOAuth2RoundTripper(clientSecret secret, config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper {
func NewOAuth2RoundTripper(clientSecret SecretReader, config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper {
if clientSecret == nil {
clientSecret = &inlineSecret{text: ""}
clientSecret = NewInlineSecret("")
}

return &oauth2RoundTripper{
Expand Down Expand Up @@ -910,7 +912,7 @@ func (rt *oauth2RoundTripper) newOauth2TokenSource(req *http.Request, secret str
if err != nil {
return nil, nil, err
}
if tlsSettings.CA == nil || tlsSettings.CA.immutable() {
if tlsSettings.CA == nil || tlsSettings.CA.Immutable() {
t, _ = tlsTransport(tlsConfig)
} else {
t, err = NewTLSRoundTripperWithContext(req.Context(), tlsConfig, tlsSettings, tlsTransport)
Expand Down Expand Up @@ -947,8 +949,8 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
rt.mtx.RUnlock()

// Fetch the secret if it's our first run or always if the secret can change.
if !rt.clientSecret.immutable() || needsInit {
newSecret, err := rt.clientSecret.fetch(req.Context())
if !rt.clientSecret.Immutable() || needsInit {
newSecret, err := rt.clientSecret.Fetch(req.Context())
if err != nil {
return nil, fmt.Errorf("unable to read oauth2 client secret: %w", err)
}
Expand Down Expand Up @@ -1052,12 +1054,12 @@ func NewTLSConfigWithContext(ctx context.Context, cfg *TLSConfig, optFuncs ...TL
return nil, fmt.Errorf("unable to use CA cert: %w", err)
}
if caSecret != nil {
ca, err := caSecret.fetch(ctx)
ca, err := caSecret.Fetch(ctx)
if err != nil {
return nil, fmt.Errorf("unable to read CA cert: %w", err)
}
if !updateRootCA(tlsConfig, []byte(ca)) {
return nil, fmt.Errorf("unable to use specified CA cert %s", caSecret.description())
return nil, fmt.Errorf("unable to use specified CA cert %s", caSecret.Description())
}
}

Expand Down Expand Up @@ -1198,7 +1200,7 @@ func (c *TLSConfig) getClientCertificate(ctx context.Context, secretManager Secr
return nil, fmt.Errorf("unable to use client cert: %w", err)
}
if certSecret != nil {
certData, err = certSecret.fetch(ctx)
certData, err = certSecret.Fetch(ctx)
if err != nil {
return nil, fmt.Errorf("unable to read specified client cert: %w", err)
}
Expand All @@ -1209,15 +1211,15 @@ func (c *TLSConfig) getClientCertificate(ctx context.Context, secretManager Secr
return nil, fmt.Errorf("unable to use client key: %w", err)
}
if keySecret != nil {
keyData, err = keySecret.fetch(ctx)
keyData, err = keySecret.Fetch(ctx)
if err != nil {
return nil, fmt.Errorf("unable to read specified client key: %w", err)
}
}

cert, err := tls.X509KeyPair([]byte(certData), []byte(keyData))
if err != nil {
return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %w", certSecret.description(), keySecret.description(), err)
return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %w", certSecret.Description(), keySecret.Description(), err)
}

return &cert, nil
Expand Down Expand Up @@ -1250,9 +1252,9 @@ type tlsRoundTripper struct {
}

type TLSRoundTripperSettings struct {
CA secret
Cert secret
Key secret
CA SecretReader
Cert SecretReader
Key SecretReader
}

func NewTLSRoundTripper(
Expand Down Expand Up @@ -1292,23 +1294,23 @@ func (t *tlsRoundTripper) getTLSDataWithHash(ctx context.Context) ([]byte, []byt
var caBytes, certBytes, keyBytes []byte

if t.settings.CA != nil {
ca, err := t.settings.CA.fetch(ctx)
ca, err := t.settings.CA.Fetch(ctx)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("unable to read CA cert: %w", err)
}
caBytes = []byte(ca)
}

if t.settings.Cert != nil {
cert, err := t.settings.Cert.fetch(ctx)
cert, err := t.settings.Cert.Fetch(ctx)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("unable to read client cert: %w", err)
}
certBytes = []byte(cert)
}

if t.settings.Key != nil {
key, err := t.settings.Key.fetch(ctx)
key, err := t.settings.Key.Fetch(ctx)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("unable to read client key: %w", err)
}
Expand Down Expand Up @@ -1353,7 +1355,7 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// using GetClientCertificate.
tlsConfig := t.tlsConfig.Clone()
if !updateRootCA(tlsConfig, caData) {
return nil, fmt.Errorf("unable to use specified CA cert %s", t.settings.CA.description())
return nil, fmt.Errorf("unable to use specified CA cert %s", t.settings.CA.Description())
}
rt, err = t.newRT(tlsConfig)
if err != nil {
Expand Down
Loading