diff --git a/plugin/pkg/client/auth/exec/exec.go b/plugin/pkg/client/auth/exec/exec.go index d37dfbf732..73876f6887 100644 --- a/plugin/pkg/client/auth/exec/exec.go +++ b/plugin/pkg/client/auth/exec/exec.go @@ -199,14 +199,18 @@ func newAuthenticator(c *cache, isTerminalFunc func(int) bool, config *api.ExecC now: time.Now, environ: os.Environ, - defaultDialer: defaultDialer, - connTracker: connTracker, + connTracker: connTracker, } for _, env := range config.Env { a.env = append(a.env, env.Name+"="+env.Value) } + // these functions are made comparable and stored in the cache so that repeated clientset + // construction with the same rest.Config results in a single TLS cache and Authenticator + a.getCert = &transport.GetCertHolder{GetCert: a.cert} + a.dial = &transport.DialHolder{Dial: defaultDialer.DialContext} + return c.put(key, a), nil } @@ -261,8 +265,6 @@ type Authenticator struct { now func() time.Time environ func() []string - // defaultDialer is used for clients which don't specify a custom dialer - defaultDialer *connrotation.Dialer // connTracker tracks all connections opened that we need to close when rotating a client certificate connTracker *connrotation.ConnectionTracker @@ -273,6 +275,12 @@ type Authenticator struct { mu sync.Mutex cachedCreds *credentials exp time.Time + + // getCert makes Authenticator.cert comparable to support TLS config caching + getCert *transport.GetCertHolder + // dial is used for clients which do not specify a custom dialer + // it is comparable to support TLS config caching + dial *transport.DialHolder } type credentials struct { @@ -300,18 +308,20 @@ func (a *Authenticator) UpdateTransportConfig(c *transport.Config) error { if c.HasCertCallback() { return errors.New("can't add TLS certificate callback: transport.Config.TLS.GetCert already set") } - c.TLS.GetCert = a.cert + c.TLS.GetCert = a.getCert.GetCert + c.TLS.GetCertHolder = a.getCert // comparable for TLS config caching - var d *connrotation.Dialer if c.Dial != nil { // if c has a custom dialer, we have to wrap it - d = connrotation.NewDialerWithTracker(c.Dial, a.connTracker) + // TLS config caching is not supported for this config + d := connrotation.NewDialerWithTracker(c.Dial, a.connTracker) + c.Dial = d.DialContext + c.DialHolder = nil } else { - d = a.defaultDialer + c.Dial = a.dial.Dial + c.DialHolder = a.dial // comparable for TLS config caching } - c.Dial = d.DialContext - return nil } diff --git a/plugin/pkg/client/auth/exec/exec_cache_test.go b/plugin/pkg/client/auth/exec/exec_cache_test.go new file mode 100644 index 0000000000..ecf84262a1 --- /dev/null +++ b/plugin/pkg/client/auth/exec/exec_cache_test.go @@ -0,0 +1,106 @@ +/* +Copyright 2022 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package exec_test // separate package to prevent circular import + +import ( + "context" + "testing" + "time" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + utilnet "k8s.io/apimachinery/pkg/util/net" + clientset "k8s.io/client-go/kubernetes" + "k8s.io/client-go/rest" + clientcmdapi "k8s.io/client-go/tools/clientcmd/api" +) + +// TestExecTLSCache asserts the semantics of the TLS cache when exec auth is used. +// +// In particular, when: +// - multiple identical rest configs exist as distinct objects, and +// - these rest configs use exec auth, and +// - these rest configs are used to create distinct clientsets, then +// +// the underlying TLS config is shared between those clientsets. +func TestExecTLSCache(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + t.Cleanup(cancel) + + config1 := &rest.Config{ + Host: "https://localhost", + ExecProvider: &clientcmdapi.ExecConfig{ + Command: "./testdata/test-plugin.sh", + APIVersion: "client.authentication.k8s.io/v1", + InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode, + }, + } + client1 := clientset.NewForConfigOrDie(config1) + + config2 := &rest.Config{ + Host: "https://localhost", + ExecProvider: &clientcmdapi.ExecConfig{ + Command: "./testdata/test-plugin.sh", + APIVersion: "client.authentication.k8s.io/v1", + InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode, + }, + } + client2 := clientset.NewForConfigOrDie(config2) + + config3 := &rest.Config{ + Host: "https://localhost", + ExecProvider: &clientcmdapi.ExecConfig{ + Command: "./testdata/test-plugin.sh", + Args: []string{"make this exec auth different"}, + APIVersion: "client.authentication.k8s.io/v1", + InteractiveMode: clientcmdapi.IfAvailableExecInteractiveMode, + }, + } + client3 := clientset.NewForConfigOrDie(config3) + + _, _ = client1.CoreV1().Nodes().List(ctx, metav1.ListOptions{}) + _, _ = client2.CoreV1().Namespaces().List(ctx, metav1.ListOptions{}) + _, _ = client3.CoreV1().PersistentVolumes().List(ctx, metav1.ListOptions{}) + + rt1 := client1.RESTClient().(*rest.RESTClient).Client.Transport + rt2 := client2.RESTClient().(*rest.RESTClient).Client.Transport + rt3 := client3.RESTClient().(*rest.RESTClient).Client.Transport + + tlsConfig1, err := utilnet.TLSClientConfig(rt1) + if err != nil { + t.Fatal(err) + } + tlsConfig2, err := utilnet.TLSClientConfig(rt2) + if err != nil { + t.Fatal(err) + } + tlsConfig3, err := utilnet.TLSClientConfig(rt3) + if err != nil { + t.Fatal(err) + } + + if tlsConfig1 == nil || tlsConfig2 == nil || tlsConfig3 == nil { + t.Fatal("expected non-nil TLS configs") + } + + if tlsConfig1 != tlsConfig2 { + t.Fatal("expected the same TLS config for matching exec config via rest config") + } + + if tlsConfig1 == tlsConfig3 { + t.Fatal("expected different TLS config for non-matching exec config via rest config") + } +} diff --git a/transport/cache.go b/transport/cache.go index 214f0a79cf..b4f8dab0c9 100644 --- a/transport/cache.go +++ b/transport/cache.go @@ -17,6 +17,7 @@ limitations under the License. package transport import ( + "context" "fmt" "net" "net/http" @@ -55,6 +56,9 @@ type tlsCacheKey struct { serverName string nextProtos string disableCompression bool + // these functions are wrapped to allow them to be used as map keys + getCert *GetCertHolder + dial *DialHolder } func (t tlsCacheKey) String() string { @@ -62,7 +66,8 @@ func (t tlsCacheKey) String() string { if len(t.keyData) > 0 { keyText = "" } - return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t", t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression) + return fmt.Sprintf("insecure:%v, caData:%#v, certData:%#v, keyData:%s, serverName:%s, disableCompression:%t, getCert:%p, dial:%p", + t.insecure, t.caData, t.certData, keyText, t.serverName, t.disableCompression, t.getCert, t.dial) } func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { @@ -92,8 +97,10 @@ func (c *tlsTransportCache) get(config *Config) (http.RoundTripper, error) { return http.DefaultTransport, nil } - dial := config.Dial - if dial == nil { + var dial func(ctx context.Context, network, address string) (net.Conn, error) + if config.Dial != nil { + dial = config.Dial + } else { dial = (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, @@ -138,10 +145,18 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) { return tlsCacheKey{}, false, err } - if c.TLS.GetCert != nil || c.Dial != nil || c.Proxy != nil { + if c.Proxy != nil { // cannot determine equality for functions return tlsCacheKey{}, false, nil } + if c.Dial != nil && c.DialHolder == nil { + // cannot determine equality for dial function that doesn't have non-nil DialHolder set as well + return tlsCacheKey{}, false, nil + } + if c.TLS.GetCert != nil && c.TLS.GetCertHolder == nil { + // cannot determine equality for getCert function that doesn't have non-nil GetCertHolder set as well + return tlsCacheKey{}, false, nil + } k := tlsCacheKey{ insecure: c.TLS.Insecure, @@ -149,6 +164,8 @@ func tlsConfigKey(c *Config) (tlsCacheKey, bool, error) { serverName: c.TLS.ServerName, nextProtos: strings.Join(c.TLS.NextProtos, ","), disableCompression: c.DisableCompression, + getCert: c.TLS.GetCertHolder, + dial: c.DialHolder, } if c.TLS.ReloadTLSFiles { diff --git a/transport/cache_test.go b/transport/cache_test.go index c6d06fcab3..87d070bb01 100644 --- a/transport/cache_test.go +++ b/transport/cache_test.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "net" "net/http" + "net/url" "testing" ) @@ -58,16 +59,24 @@ func TestTLSConfigKey(t *testing.T) { t.Errorf("Expected identical cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) continue } + if keyA != (tlsCacheKey{}) { + t.Errorf("Expected empty cache keys for %q and %q, got:\n\t%s\n\t%s", nameA, nameB, keyA, keyB) + continue + } } } // Make sure config fields that affect the tls config affect the cache key dialer := net.Dialer{} getCert := func() (*tls.Certificate, error) { return nil, nil } + getCertHolder := &GetCertHolder{GetCert: getCert} uniqueConfigurations := map[string]*Config{ + "proxy": {Proxy: func(request *http.Request) (*url.URL, error) { return nil, nil }}, "no tls": {}, "dialer": {Dial: dialer.DialContext}, "dialer2": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}, + "dialer3": {Dial: dialer.DialContext, DialHolder: &DialHolder{Dial: dialer.DialContext}}, + "dialer4": {Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, DialHolder: &DialHolder{Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }}}, "insecure": {TLS: TLSConfig{Insecure: true}}, "cadata 1": {TLS: TLSConfig{CAData: []byte{1}}}, "cadata 2": {TLS: TLSConfig{CAData: []byte{2}}}, @@ -128,6 +137,13 @@ func TestTLSConfigKey(t *testing.T) { GetCert: func() (*tls.Certificate, error) { return nil, nil }, }, }, + "getCert3": { + TLS: TLSConfig{ + KeyData: []byte{1}, + GetCert: getCert, + GetCertHolder: getCertHolder, + }, + }, "getCert1, key 2": { TLS: TLSConfig{ KeyData: []byte{2}, diff --git a/transport/config.go b/transport/config.go index 89de798f60..fd853c0b39 100644 --- a/transport/config.go +++ b/transport/config.go @@ -68,7 +68,11 @@ type Config struct { WrapTransport WrapperFunc // Dial specifies the dial function for creating unencrypted TCP connections. + // If specified, this transport will be non-cacheable unless DialHolder is also set. Dial func(ctx context.Context, network, address string) (net.Conn, error) + // DialHolder can be populated to make transport configs cacheable. + // If specified, DialHolder.Dial must be equal to Dial. + DialHolder *DialHolder // Proxy is the proxy func to be used for all requests made by this // transport. If Proxy is nil, http.ProxyFromEnvironment is used. If Proxy @@ -78,6 +82,11 @@ type Config struct { Proxy func(*http.Request) (*url.URL, error) } +// DialHolder is used to make the wrapped function comparable so that it can be used as a map key. +type DialHolder struct { + Dial func(ctx context.Context, network, address string) (net.Conn, error) +} + // ImpersonationConfig has all the available impersonation options type ImpersonationConfig struct { // UserName matches user.Info.GetName() @@ -143,5 +152,15 @@ type TLSConfig struct { // To use only http/1.1, set to ["http/1.1"]. NextProtos []string - GetCert func() (*tls.Certificate, error) // Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field. + // Callback that returns a TLS client certificate. CertData, CertFile, KeyData and KeyFile supercede this field. + // If specified, this transport is non-cacheable unless CertHolder is populated. + GetCert func() (*tls.Certificate, error) + // CertHolder can be populated to make transport configs that set GetCert cacheable. + // If set, CertHolder.GetCert must be equal to GetCert. + GetCertHolder *GetCertHolder +} + +// GetCertHolder is used to make the wrapped function comparable so that it can be used as a map key. +type GetCertHolder struct { + GetCert func() (*tls.Certificate, error) } diff --git a/transport/transport.go b/transport/transport.go index b4a7bfa67c..eabfce72d0 100644 --- a/transport/transport.go +++ b/transport/transport.go @@ -24,6 +24,7 @@ import ( "fmt" "io/ioutil" "net/http" + "reflect" "sync" "time" @@ -39,6 +40,10 @@ func New(config *Config) (http.RoundTripper, error) { return nil, fmt.Errorf("using a custom transport with TLS certificate options or the insecure flag is not allowed") } + if !isValidHolders(config) { + return nil, fmt.Errorf("misconfigured holder for dialer or cert callback") + } + var ( rt http.RoundTripper err error @@ -56,6 +61,26 @@ func New(config *Config) (http.RoundTripper, error) { return HTTPWrappersForConfig(config, rt) } +func isValidHolders(config *Config) bool { + if config.TLS.GetCertHolder != nil { + if config.TLS.GetCertHolder.GetCert == nil || + config.TLS.GetCert == nil || + reflect.ValueOf(config.TLS.GetCertHolder.GetCert).Pointer() != reflect.ValueOf(config.TLS.GetCert).Pointer() { + return false + } + } + + if config.DialHolder != nil { + if config.DialHolder.Dial == nil || + config.Dial == nil || + reflect.ValueOf(config.DialHolder.Dial).Pointer() != reflect.ValueOf(config.Dial).Pointer() { + return false + } + } + + return true +} + // TLSConfigFor returns a tls.Config that will provide the transport level security defined // by the provided Config. Will return nil if no transport level security is requested. func TLSConfigFor(c *Config) (*tls.Config, error) { diff --git a/transport/transport_test.go b/transport/transport_test.go index c439c96f81..e0fd2679a5 100644 --- a/transport/transport_test.go +++ b/transport/transport_test.go @@ -21,6 +21,7 @@ import ( "crypto/tls" "errors" "fmt" + "net" "net/http" "testing" ) @@ -94,6 +95,13 @@ stR0Yiw0buV6DL/moUO0HIM9Bjh96HJp+LxiIS6UCdIhMPp5HoQa ) func TestNew(t *testing.T) { + globalGetCert := &GetCertHolder{ + GetCert: func() (*tls.Certificate, error) { return nil, nil }, + } + globalDial := &DialHolder{ + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, + } + testCases := map[string]struct { Config *Config Err bool @@ -255,6 +263,144 @@ func TestNew(t *testing.T) { }, }, }, + "nil holders and nil regular": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: nil, + GetCertHolder: nil, + }, + Dial: nil, + DialHolder: nil, + }, + Err: false, + TLS: false, + TLSCert: false, + TLSErr: false, + Default: true, + Insecure: false, + DefaultRoots: false, + }, + "nil holders and non-nil regular get cert": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: func() (*tls.Certificate, error) { return nil, nil }, + GetCertHolder: nil, + }, + Dial: nil, + DialHolder: nil, + }, + Err: false, + TLS: true, + TLSCert: true, + TLSErr: false, + Default: false, + Insecure: false, + DefaultRoots: true, + }, + "nil holders and non-nil regular dial": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: nil, + GetCertHolder: nil, + }, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, + DialHolder: nil, + }, + Err: false, + TLS: true, + TLSCert: false, + TLSErr: false, + Default: false, + Insecure: false, + DefaultRoots: true, + }, + "non-nil dial holder and nil regular": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: nil, + GetCertHolder: nil, + }, + Dial: nil, + DialHolder: &DialHolder{}, + }, + Err: true, + }, + "non-nil cert holder and nil regular": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: nil, + GetCertHolder: &GetCertHolder{}, + }, + Dial: nil, + DialHolder: nil, + }, + Err: true, + }, + "non-nil dial holder and non-nil regular": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: nil, + GetCertHolder: nil, + }, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, + DialHolder: &DialHolder{}, + }, + Err: true, + }, + "non-nil cert holder and non-nil regular": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: func() (*tls.Certificate, error) { return nil, nil }, + GetCertHolder: &GetCertHolder{}, + }, + Dial: nil, + DialHolder: nil, + }, + Err: true, + }, + "non-nil dial holder+internal and non-nil regular": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: nil, + GetCertHolder: nil, + }, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, + DialHolder: &DialHolder{ + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { return nil, nil }, + }, + }, + Err: true, + }, + "non-nil cert holder+internal and non-nil regular": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: func() (*tls.Certificate, error) { return nil, nil }, + GetCertHolder: &GetCertHolder{ + GetCert: func() (*tls.Certificate, error) { return nil, nil }, + }, + }, + Dial: nil, + DialHolder: nil, + }, + Err: true, + }, + "non-nil holders+internal and non-nil regular with correct address": { + Config: &Config{ + TLS: TLSConfig{ + GetCert: globalGetCert.GetCert, + GetCertHolder: globalGetCert, + }, + Dial: globalDial.Dial, + DialHolder: globalDial, + }, + Err: false, + TLS: true, + TLSCert: true, + TLSErr: false, + Default: false, + Insecure: false, + DefaultRoots: true, + }, } for k, testCase := range testCases { t.Run(k, func(t *testing.T) {