Skip to content

Commit

Permalink
Merge pull request #650 from pracucci/export-secret
Browse files Browse the repository at this point in the history
Expose secret as SecretReader and InlineSecret from config package
  • Loading branch information
gotjosh committed Jun 7, 2024
2 parents d726751 + 43e45c3 commit 92fc65e
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 62 deletions.
98 changes: 50 additions & 48 deletions config/http_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientCon
if err != nil {
return nil, err
}
if tlsSettings.CA == nil || tlsSettings.CA.immutable() {
if tlsSettings.CA == nil || tlsSettings.CA.Immutable() {
// No need for a RoundTripper that reloads the CA file automatically.
return newRT(tlsConfig)
}
Expand All @@ -692,73 +692,75 @@ type SecretManager interface {
Fetch(ctx context.Context, secretRef string) (string, error)
}

type secret interface {
fetch(ctx context.Context) (string, error)
description() string
immutable() bool
type SecretReader interface {
Fetch(ctx context.Context) (string, error)
Description() string
Immutable() bool
}

type inlineSecret struct {
type InlineSecret struct {
text string
}

func (s *inlineSecret) fetch(context.Context) (string, error) {
func NewInlineSecret(text string) *InlineSecret {
return &InlineSecret{text: text}
}

func (s *InlineSecret) Fetch(context.Context) (string, error) {
return s.text, nil
}

func (s *inlineSecret) description() string {
func (s *InlineSecret) Description() string {
return "inline"
}

func (s *inlineSecret) immutable() bool {
func (s *InlineSecret) Immutable() bool {
return true
}

type fileSecret struct {
file string
}

func (s *fileSecret) fetch(ctx context.Context) (string, error) {
func (s *fileSecret) Fetch(ctx context.Context) (string, error) {
fileBytes, err := os.ReadFile(s.file)
if err != nil {
return "", fmt.Errorf("unable to read file %s: %w", s.file, err)
}
return strings.TrimSpace(string(fileBytes)), nil
}

func (s *fileSecret) description() string {
func (s *fileSecret) Description() string {
return fmt.Sprintf("file %s", s.file)
}

func (s *fileSecret) immutable() bool {
func (s *fileSecret) Immutable() bool {
return false
}

// refSecret fetches a single secret from a secret manager.
// refSecret fetches a single secret from a SecretManager.
type refSecret struct {
ref string
manager SecretManager // manager is expected to be not nil.
}

func (s *refSecret) fetch(ctx context.Context) (string, error) {
func (s *refSecret) Fetch(ctx context.Context) (string, error) {
return s.manager.Fetch(ctx, s.ref)
}

func (s *refSecret) description() string {
func (s *refSecret) Description() string {
return fmt.Sprintf("ref %s", s.ref)
}

func (s *refSecret) immutable() bool {
func (s *refSecret) Immutable() bool {
return false
}

// toSecret returns a secret from one of the given sources, assuming exactly
// toSecret returns a SecretReader from one of the given sources, assuming exactly
// one or none of the sources are provided.
func toSecret(secretManager SecretManager, text Secret, file, ref string) (secret, error) {
func toSecret(secretManager SecretManager, text Secret, file, ref string) (SecretReader, error) {
if text != "" {
return &inlineSecret{
text: string(text),
}, nil
return NewInlineSecret(string(text)), nil
}
if file != "" {
return &fileSecret{
Expand All @@ -779,14 +781,14 @@ func toSecret(secretManager SecretManager, text Secret, file, ref string) (secre

type authorizationCredentialsRoundTripper struct {
authType string
authCredentials secret
authCredentials SecretReader
rt http.RoundTripper
}

// NewAuthorizationCredentialsRoundTripper adds the authorization credentials
// read from the provided secret to a request unless the authorization header
// read from the provided SecretReader to a request unless the authorization header
// has already been set.
func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials secret, rt http.RoundTripper) http.RoundTripper {
func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials SecretReader, rt http.RoundTripper) http.RoundTripper {
return &authorizationCredentialsRoundTripper{authType, authCredentials, rt}
}

Expand All @@ -798,7 +800,7 @@ func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*h
var authCredentials string
if rt.authCredentials != nil {
var err error
authCredentials, err = rt.authCredentials.fetch(req.Context())
authCredentials, err = rt.authCredentials.Fetch(req.Context())
if err != nil {
return nil, fmt.Errorf("unable to read authorization credentials: %w", err)
}
Expand All @@ -817,14 +819,14 @@ func (rt *authorizationCredentialsRoundTripper) CloseIdleConnections() {
}

type basicAuthRoundTripper struct {
username secret
password secret
username SecretReader
password SecretReader
rt http.RoundTripper
}

// NewBasicAuthRoundTripper will apply a BASIC auth authorization header to a request unless it has
// already been set.
func NewBasicAuthRoundTripper(username secret, password secret, rt http.RoundTripper) http.RoundTripper {
func NewBasicAuthRoundTripper(username SecretReader, password SecretReader, rt http.RoundTripper) http.RoundTripper {
return &basicAuthRoundTripper{username, password, rt}
}

Expand All @@ -836,14 +838,14 @@ func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
var password string
if rt.username != nil {
var err error
username, err = rt.username.fetch(req.Context())
username, err = rt.username.Fetch(req.Context())
if err != nil {
return nil, fmt.Errorf("unable to read basic auth username: %w", err)
}
}
if rt.password != nil {
var err error
password, err = rt.password.fetch(req.Context())
password, err = rt.password.Fetch(req.Context())
if err != nil {
return nil, fmt.Errorf("unable to read basic auth password: %w", err)
}
Expand All @@ -866,14 +868,14 @@ type oauth2RoundTripper struct {

// Required for interaction with Oauth2 server.
config *OAuth2
clientSecret secret
clientSecret SecretReader
opts *httpClientOptions
client *http.Client
}

func NewOAuth2RoundTripper(clientSecret secret, config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper {
func NewOAuth2RoundTripper(clientSecret SecretReader, config *OAuth2, next http.RoundTripper, opts *httpClientOptions) http.RoundTripper {
if clientSecret == nil {
clientSecret = &inlineSecret{text: ""}
clientSecret = NewInlineSecret("")
}

return &oauth2RoundTripper{
Expand Down Expand Up @@ -910,7 +912,7 @@ func (rt *oauth2RoundTripper) newOauth2TokenSource(req *http.Request, secret str
if err != nil {
return nil, nil, err
}
if tlsSettings.CA == nil || tlsSettings.CA.immutable() {
if tlsSettings.CA == nil || tlsSettings.CA.Immutable() {
t, _ = tlsTransport(tlsConfig)
} else {
t, err = NewTLSRoundTripperWithContext(req.Context(), tlsConfig, tlsSettings, tlsTransport)
Expand Down Expand Up @@ -947,8 +949,8 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro
rt.mtx.RUnlock()

// Fetch the secret if it's our first run or always if the secret can change.
if !rt.clientSecret.immutable() || needsInit {
newSecret, err := rt.clientSecret.fetch(req.Context())
if !rt.clientSecret.Immutable() || needsInit {
newSecret, err := rt.clientSecret.Fetch(req.Context())
if err != nil {
return nil, fmt.Errorf("unable to read oauth2 client secret: %w", err)
}
Expand Down Expand Up @@ -1052,12 +1054,12 @@ func NewTLSConfigWithContext(ctx context.Context, cfg *TLSConfig, optFuncs ...TL
return nil, fmt.Errorf("unable to use CA cert: %w", err)
}
if caSecret != nil {
ca, err := caSecret.fetch(ctx)
ca, err := caSecret.Fetch(ctx)
if err != nil {
return nil, fmt.Errorf("unable to read CA cert: %w", err)
}
if !updateRootCA(tlsConfig, []byte(ca)) {
return nil, fmt.Errorf("unable to use specified CA cert %s", caSecret.description())
return nil, fmt.Errorf("unable to use specified CA cert %s", caSecret.Description())
}
}

Expand Down Expand Up @@ -1198,7 +1200,7 @@ func (c *TLSConfig) getClientCertificate(ctx context.Context, secretManager Secr
return nil, fmt.Errorf("unable to use client cert: %w", err)
}
if certSecret != nil {
certData, err = certSecret.fetch(ctx)
certData, err = certSecret.Fetch(ctx)
if err != nil {
return nil, fmt.Errorf("unable to read specified client cert: %w", err)
}
Expand All @@ -1209,15 +1211,15 @@ func (c *TLSConfig) getClientCertificate(ctx context.Context, secretManager Secr
return nil, fmt.Errorf("unable to use client key: %w", err)
}
if keySecret != nil {
keyData, err = keySecret.fetch(ctx)
keyData, err = keySecret.Fetch(ctx)
if err != nil {
return nil, fmt.Errorf("unable to read specified client key: %w", err)
}
}

cert, err := tls.X509KeyPair([]byte(certData), []byte(keyData))
if err != nil {
return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %w", certSecret.description(), keySecret.description(), err)
return nil, fmt.Errorf("unable to use specified client cert (%s) & key (%s): %w", certSecret.Description(), keySecret.Description(), err)
}

return &cert, nil
Expand Down Expand Up @@ -1250,9 +1252,9 @@ type tlsRoundTripper struct {
}

type TLSRoundTripperSettings struct {
CA secret
Cert secret
Key secret
CA SecretReader
Cert SecretReader
Key SecretReader
}

func NewTLSRoundTripper(
Expand Down Expand Up @@ -1292,23 +1294,23 @@ func (t *tlsRoundTripper) getTLSDataWithHash(ctx context.Context) ([]byte, []byt
var caBytes, certBytes, keyBytes []byte

if t.settings.CA != nil {
ca, err := t.settings.CA.fetch(ctx)
ca, err := t.settings.CA.Fetch(ctx)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("unable to read CA cert: %w", err)
}
caBytes = []byte(ca)
}

if t.settings.Cert != nil {
cert, err := t.settings.Cert.fetch(ctx)
cert, err := t.settings.Cert.Fetch(ctx)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("unable to read client cert: %w", err)
}
certBytes = []byte(cert)
}

if t.settings.Key != nil {
key, err := t.settings.Key.fetch(ctx)
key, err := t.settings.Key.Fetch(ctx)
if err != nil {
return nil, nil, nil, nil, fmt.Errorf("unable to read client key: %w", err)
}
Expand Down Expand Up @@ -1353,7 +1355,7 @@ func (t *tlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
// using GetClientCertificate.
tlsConfig := t.tlsConfig.Clone()
if !updateRootCA(tlsConfig, caData) {
return nil, fmt.Errorf("unable to use specified CA cert %s", t.settings.CA.description())
return nil, fmt.Errorf("unable to use specified CA cert %s", t.settings.CA.Description())
}
rt, err = t.newRT(tlsConfig)
if err != nil {
Expand Down
Loading

0 comments on commit 92fc65e

Please sign in to comment.