Skip to content

Commit

Permalink
feat: Add custom cert support (#250)
Browse files Browse the repository at this point in the history
add support for client cert override per cluster via a secret
  • Loading branch information
joekr authored May 10, 2023
1 parent e09e016 commit a658119
Show file tree
Hide file tree
Showing 11 changed files with 417 additions and 29 deletions.
7 changes: 7 additions & 0 deletions api/v1beta2/ocicluster_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ type OCIAvailabilityDomain struct {
// ClientOverrides contains information about client host url overrides.
type ClientOverrides struct {

// CertOverride is a secret that contains information about a cert override used by all the OCI SDK clients.
// The secret must contain data with a `cert`property.
//
// +optional
// +nullable
CertOverride *corev1.SecretReference `json:"certOverride,omitempty"`

// ComputeClientUrl allows the default compute SDK client URL to be changed.
//
// +optional
Expand Down
5 changes: 5 additions & 0 deletions api/v1beta2/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

141 changes: 127 additions & 14 deletions cloud/scope/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package scope

import (
"crypto/tls"
"crypto/x509"
"net/http"
"sync"

Expand Down Expand Up @@ -60,6 +62,7 @@ type ClientProvider struct {
ociClientsLock *sync.RWMutex
ociAuthConfigProvider common.ConfigurationProvider
ociClientOverrides *v1beta2.ClientOverrides
certOverride *x509.CertPool
}

// ClientProviderParams is the params struct for NewClientProvider
Expand All @@ -69,6 +72,9 @@ type ClientProviderParams struct {

// ClientOverrides contains information about client host url overrides.
ClientOverrides *v1beta2.ClientOverrides

// CertOverride a x509 CertPool to use as an override for client TLSClientConfig
CertOverride *x509.CertPool
}

// NewClientProvider builds the ClientProvider with a client for the given region
Expand All @@ -81,6 +87,7 @@ func NewClientProvider(params ClientProviderParams) (*ClientProvider, error) {

provider := ClientProvider{
Logger: &log,
certOverride: params.CertOverride,
ociAuthConfigProvider: params.OciAuthConfigProvider,
ociClients: map[string]OCIClients{},
ociClientsLock: new(sync.RWMutex),
Expand Down Expand Up @@ -138,7 +145,7 @@ func (c *ClientProvider) createClients(region string) (OCIClients, error) {
if err != nil {
return OCIClients{}, err
}
identityClient, err := c.createIdentityClient(region, c.ociAuthConfigProvider, c.Logger)
identityClt, err := c.createIdentityClient(region, c.ociAuthConfigProvider, c.Logger)
if err != nil {
return OCIClients{}, err
}
Expand All @@ -150,7 +157,7 @@ func (c *ClientProvider) createClients(region string) (OCIClients, error) {
if err != nil {
return OCIClients{}, err
}
containerEngineClient, err := c.createContainerEngineClient(region, c.ociAuthConfigProvider, c.Logger)
containerEngineClt, err := c.createContainerEngineClient(region, c.ociAuthConfigProvider, c.Logger)
if err != nil {
return OCIClients{}, err
}
Expand All @@ -167,10 +174,10 @@ func (c *ClientProvider) createClients(region string) (OCIClients, error) {
VCNClient: vcnClient,
NetworkLoadBalancerClient: nlbClient,
LoadBalancerClient: lbClient,
IdentityClient: identityClient,
IdentityClient: identityClt,
ComputeClient: computeClient,
ComputeManagementClient: computeManagementClient,
ContainerEngineClient: containerEngineClient,
ContainerEngineClient: containerEngineClt,
BaseClient: baseClient,
}, err
}
Expand All @@ -182,6 +189,20 @@ func (c *ClientProvider) createVncClient(region string, ociAuthConfigProvider co
return nil, err
}
vcnClient.SetRegion(region)

if c.certOverride != nil {
if client, ok := vcnClient.HTTPClient.(*http.Client); ok {
err = c.setCerts(client)
if err != nil {
logger.Error(err, "unable to create OCI VCN Client")
return nil, err
}
} else {
return nil, errors.New("The VCN Client dispatcher is not of http.Client type. Can not patch the tls config.")
}

}

if c.ociClientOverrides != nil && c.ociClientOverrides.VCNClientUrl != nil {
vcnClient.Host = *c.ociClientOverrides.VCNClientUrl
}
Expand All @@ -197,6 +218,19 @@ func (c *ClientProvider) createNLbClient(region string, ociAuthConfigProvider co
return nil, err
}
nlbClient.SetRegion(region)

if c.certOverride != nil {
if client, ok := nlbClient.HTTPClient.(*http.Client); ok {
err = c.setCerts(client)
if err != nil {
logger.Error(err, "unable to create OCI NetworkLoadBalancer Client")
return nil, err
}
} else {
return nil, errors.New("The Network Loadbalancer Client dispatcher is not of http.Client type. Can not patch the tls config.")
}
}

if c.ociClientOverrides != nil && c.ociClientOverrides.NetworkLoadBalancerClientUrl != nil {
nlbClient.Host = *c.ociClientOverrides.NetworkLoadBalancerClientUrl
}
Expand All @@ -212,6 +246,19 @@ func (c *ClientProvider) createLBClient(region string, ociAuthConfigProvider com
return nil, err
}
lbClient.SetRegion(region)

if c.certOverride != nil {
if client, ok := lbClient.HTTPClient.(*http.Client); ok {
err = c.setCerts(client)
if err != nil {
logger.Error(err, "unable to create OCI Loadbalancer Client")
return nil, err
}
} else {
return nil, errors.New("The Loadbalancer Client dispatcher is not of http.Client type. Can not patch the tls config.")
}
}

if c.ociClientOverrides != nil && c.ociClientOverrides.LoadBalancerClientUrl != nil {
lbClient.Host = *c.ociClientOverrides.LoadBalancerClientUrl
}
Expand All @@ -221,19 +268,31 @@ func (c *ClientProvider) createLBClient(region string, ociAuthConfigProvider com
}

func (c *ClientProvider) createIdentityClient(region string, ociAuthConfigProvider common.ConfigurationProvider, logger *logr.Logger) (*identity.IdentityClient, error) {
identityClient, err := identity.NewIdentityClientWithConfigurationProvider(ociAuthConfigProvider)
identityClt, err := identity.NewIdentityClientWithConfigurationProvider(ociAuthConfigProvider)
if err != nil {
logger.Error(err, "unable to create OCI Identity Client")
return nil, err
}
identityClient.SetRegion(region)
identityClt.SetRegion(region)

if c.certOverride != nil {
if client, ok := identityClt.HTTPClient.(*http.Client); ok {
err = c.setCerts(client)
if err != nil {
logger.Error(err, "unable to create OCI Identity Client")
return nil, err
}
} else {
return nil, errors.New("The Identity Client dispatcher is not of http.Client type. Can not patch the tls config.")
}
}

if c.ociClientOverrides != nil && c.ociClientOverrides.IdentityClientUrl != nil {
identityClient.Host = *c.ociClientOverrides.IdentityClientUrl
identityClt.Host = *c.ociClientOverrides.IdentityClientUrl
}
identityClient.Interceptor = setVersionHeader()
identityClt.Interceptor = setVersionHeader()

return &identityClient, nil
return &identityClt, nil
}

func (c *ClientProvider) createComputeClient(region string, ociAuthConfigProvider common.ConfigurationProvider, logger *logr.Logger) (*core.ComputeClient, error) {
Expand All @@ -243,6 +302,19 @@ func (c *ClientProvider) createComputeClient(region string, ociAuthConfigProvide
return nil, err
}
computeClient.SetRegion(region)

if c.certOverride != nil {
if client, ok := computeClient.HTTPClient.(*http.Client); ok {
err = c.setCerts(client)
if err != nil {
logger.Error(err, "unable to create OCI Compute Client")
return nil, err
}
} else {
return nil, errors.New("The Compute Client dispatcher is not of http.Client type. Can not patch the tls config.")
}
}

if c.ociClientOverrides != nil && c.ociClientOverrides.ComputeClientUrl != nil {
computeClient.Host = *c.ociClientOverrides.ComputeClientUrl
}
Expand All @@ -258,6 +330,19 @@ func (c *ClientProvider) createComputeManagementClient(region string, ociAuthCon
return nil, err
}
computeManagementClient.SetRegion(region)

if c.certOverride != nil {
if client, ok := computeManagementClient.HTTPClient.(*http.Client); ok {
err = c.setCerts(client)
if err != nil {
logger.Error(err, "unable to create OCI Compute Management Client")
return nil, err
}
} else {
return nil, errors.New("The Compute Management Client dispatcher is not of http.Client type. Can not patch the tls config.")
}
}

if c.ociClientOverrides != nil && c.ociClientOverrides.ComputeManagementClientUrl != nil {
computeManagementClient.Host = *c.ociClientOverrides.ComputeManagementClientUrl
}
Expand All @@ -267,18 +352,31 @@ func (c *ClientProvider) createComputeManagementClient(region string, ociAuthCon
}

func (c *ClientProvider) createContainerEngineClient(region string, ociAuthConfigProvider common.ConfigurationProvider, logger *logr.Logger) (*containerengine.ContainerEngineClient, error) {
containerEngineClient, err := containerengine.NewContainerEngineClientWithConfigurationProvider(ociAuthConfigProvider)
containerEngineClt, err := containerengine.NewContainerEngineClientWithConfigurationProvider(ociAuthConfigProvider)
if err != nil {
logger.Error(err, "unable to create OCI Container Engine Client")
return nil, err
}
containerEngineClient.SetRegion(region)
containerEngineClt.SetRegion(region)

if c.certOverride != nil {
if client, ok := containerEngineClt.HTTPClient.(*http.Client); ok {
err = c.setCerts(client)
if err != nil {
logger.Error(err, "unable to create OCI Container Engine Client")
return nil, err
}
} else {
return nil, errors.New("The Container Engine Client dispatcher is not of http.Client type. Can not patch the tls config.")
}
}

if c.ociClientOverrides != nil && c.ociClientOverrides.ContainerEngineClientUrl != nil {
containerEngineClient.Host = *c.ociClientOverrides.ContainerEngineClientUrl
containerEngineClt.Host = *c.ociClientOverrides.ContainerEngineClientUrl
}
containerEngineClient.Interceptor = setVersionHeader()
containerEngineClt.Interceptor = setVersionHeader()

return &containerEngineClient, nil
return &containerEngineClt, nil
}

func (c *ClientProvider) createBaseClient(region string, ociAuthConfigProvider common.ConfigurationProvider, logger *logr.Logger) (base.BaseClient, error) {
Expand All @@ -296,3 +394,18 @@ func setVersionHeader() func(request *http.Request) error {
return nil
}
}

// setCerts updates the client TLSClientConfig with the ClientProvider certOverride
func (c *ClientProvider) setCerts(client *http.Client) error {
tr := client.Transport.(*http.Transport).Clone()
if tr.TLSClientConfig != nil {
tr.TLSClientConfig.RootCAs = c.certOverride
} else {
tr.TLSClientConfig = &tls.Config{
RootCAs: c.certOverride,
}
}
client.Transport = tr

return nil
}
16 changes: 7 additions & 9 deletions cloud/scope/clients_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ func TestClients_NewClientProvider(t *testing.T) {
}

clientProvider, err := NewClientProvider(ClientProviderParams{
ociAuthConfigProvider,
nil})
OciAuthConfigProvider: ociAuthConfigProvider})
if err != nil {
t.Errorf("Expected %v to equal nil", err)
}
Expand Down Expand Up @@ -73,8 +72,8 @@ func TestClients_NewClientProviderWithClientOverrides(t *testing.T) {
}

clientProvider, err := NewClientProvider(ClientProviderParams{
ociAuthConfigProvider,
clientOverrides})
OciAuthConfigProvider: ociAuthConfigProvider,
ClientOverrides: clientOverrides})
if err != nil {
t.Errorf("Expected error:%v to not equal nil", err)
}
Expand Down Expand Up @@ -113,8 +112,8 @@ func TestClients_NewClientProviderWithMissingOverrides(t *testing.T) {
}

clientProvider, err := NewClientProvider(ClientProviderParams{
ociAuthConfigProvider,
clientOverrides})
OciAuthConfigProvider: ociAuthConfigProvider,
ClientOverrides: clientOverrides})
if err != nil {
t.Errorf("Expected error:%v to not equal nil", err)
}
Expand All @@ -129,7 +128,7 @@ func TestClients_NewClientProviderWithMissingOverrides(t *testing.T) {
}

func TestClients_NewClientProviderWithBadAuthConfig(t *testing.T) {
clientProvider, err := NewClientProvider(ClientProviderParams{nil, nil})
clientProvider, err := NewClientProvider(ClientProviderParams{})
if err == nil {
t.Errorf("Expected error:%v to not equal nil", err)
}
Expand Down Expand Up @@ -211,8 +210,7 @@ func TestClients_GetAuthProvider(t *testing.T) {
}

clientProvider, err := NewClientProvider(ClientProviderParams{
ociAuthConfigProvider,
nil})
OciAuthConfigProvider: ociAuthConfigProvider})
if err != nil {
t.Errorf("Expected %v to equal nil", err)
}
Expand Down
Loading

0 comments on commit a658119

Please sign in to comment.