Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add custom cert support #250

Merged
merged 5 commits into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions api/v1beta2/ocicluster_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ type OCIAvailabilityDomain struct {
// ClientOverrides contains information about client host url overrides.
type ClientOverrides struct {

// CertOverride is a secret that contains information about a cert override used by all the OCI SDK clients.
// The secret must contain data with a `cert`property.
//
// +optional
// +nullable
CertOverride *corev1.SecretReference `json:"certOverride,omitempty"`

// ComputeClientUrl allows the default compute SDK client URL to be changed.
//
// +optional
Expand Down
5 changes: 5 additions & 0 deletions api/v1beta2/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

141 changes: 127 additions & 14 deletions cloud/scope/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ limitations under the License.
package scope

import (
"crypto/tls"
"crypto/x509"
"net/http"
"sync"

Expand Down Expand Up @@ -60,6 +62,7 @@ type ClientProvider struct {
ociClientsLock *sync.RWMutex
ociAuthConfigProvider common.ConfigurationProvider
ociClientOverrides *v1beta2.ClientOverrides
certOverride *x509.CertPool
}

// ClientProviderParams is the params struct for NewClientProvider
Expand All @@ -69,6 +72,9 @@ type ClientProviderParams struct {

// ClientOverrides contains information about client host url overrides.
ClientOverrides *v1beta2.ClientOverrides

// CertOverride a x509 CertPool to use as an override for client TLSClientConfig
CertOverride *x509.CertPool
}

// NewClientProvider builds the ClientProvider with a client for the given region
Expand All @@ -81,6 +87,7 @@ func NewClientProvider(params ClientProviderParams) (*ClientProvider, error) {

provider := ClientProvider{
Logger: &log,
certOverride: params.CertOverride,
ociAuthConfigProvider: params.OciAuthConfigProvider,
ociClients: map[string]OCIClients{},
ociClientsLock: new(sync.RWMutex),
Expand Down Expand Up @@ -138,7 +145,7 @@ func (c *ClientProvider) createClients(region string) (OCIClients, error) {
if err != nil {
return OCIClients{}, err
}
identityClient, err := c.createIdentityClient(region, c.ociAuthConfigProvider, c.Logger)
identityClt, err := c.createIdentityClient(region, c.ociAuthConfigProvider, c.Logger)
joekr marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return OCIClients{}, err
}
Expand All @@ -150,7 +157,7 @@ func (c *ClientProvider) createClients(region string) (OCIClients, error) {
if err != nil {
return OCIClients{}, err
}
containerEngineClient, err := c.createContainerEngineClient(region, c.ociAuthConfigProvider, c.Logger)
containerEngineClt, err := c.createContainerEngineClient(region, c.ociAuthConfigProvider, c.Logger)
if err != nil {
return OCIClients{}, err
}
Expand All @@ -167,10 +174,10 @@ func (c *ClientProvider) createClients(region string) (OCIClients, error) {
VCNClient: vcnClient,
NetworkLoadBalancerClient: nlbClient,
LoadBalancerClient: lbClient,
IdentityClient: identityClient,
IdentityClient: identityClt,
ComputeClient: computeClient,
ComputeManagementClient: computeManagementClient,
ContainerEngineClient: containerEngineClient,
ContainerEngineClient: containerEngineClt,
BaseClient: baseClient,
}, err
}
Expand All @@ -182,6 +189,20 @@ func (c *ClientProvider) createVncClient(region string, ociAuthConfigProvider co
return nil, err
}
vcnClient.SetRegion(region)

if c.certOverride != nil {
if client, ok := vcnClient.HTTPClient.(*http.Client); ok {
err = c.setCerts(client)
if err != nil {
logger.Error(err, "unable to create OCI VCN Client")
return nil, err
}
} else {
return nil, errors.New("The VCN Client dispatcher is not of http.Client type. Can not patch the tls config.")
}

}

if c.ociClientOverrides != nil && c.ociClientOverrides.VCNClientUrl != nil {
vcnClient.Host = *c.ociClientOverrides.VCNClientUrl
}
Expand All @@ -197,6 +218,19 @@ func (c *ClientProvider) createNLbClient(region string, ociAuthConfigProvider co
return nil, err
}
nlbClient.SetRegion(region)

if c.certOverride != nil {
if client, ok := nlbClient.HTTPClient.(*http.Client); ok {
err = c.setCerts(client)
if err != nil {
logger.Error(err, "unable to create OCI NetworkLoadBalancer Client")
return nil, err
}
} else {
return nil, errors.New("The Network Loadbalancer Client dispatcher is not of http.Client type. Can not patch the tls config.")
}
}

if c.ociClientOverrides != nil && c.ociClientOverrides.NetworkLoadBalancerClientUrl != nil {
nlbClient.Host = *c.ociClientOverrides.NetworkLoadBalancerClientUrl
}
Expand All @@ -212,6 +246,19 @@ func (c *ClientProvider) createLBClient(region string, ociAuthConfigProvider com
return nil, err
}
lbClient.SetRegion(region)

if c.certOverride != nil {
if client, ok := lbClient.HTTPClient.(*http.Client); ok {
err = c.setCerts(client)
if err != nil {
logger.Error(err, "unable to create OCI Loadbalancer Client")
return nil, err
}
} else {
return nil, errors.New("The Loadbalancer Client dispatcher is not of http.Client type. Can not patch the tls config.")
}
}

if c.ociClientOverrides != nil && c.ociClientOverrides.LoadBalancerClientUrl != nil {
lbClient.Host = *c.ociClientOverrides.LoadBalancerClientUrl
}
Expand All @@ -221,19 +268,31 @@ func (c *ClientProvider) createLBClient(region string, ociAuthConfigProvider com
}

func (c *ClientProvider) createIdentityClient(region string, ociAuthConfigProvider common.ConfigurationProvider, logger *logr.Logger) (*identity.IdentityClient, error) {
identityClient, err := identity.NewIdentityClientWithConfigurationProvider(ociAuthConfigProvider)
identityClt, err := identity.NewIdentityClientWithConfigurationProvider(ociAuthConfigProvider)
if err != nil {
logger.Error(err, "unable to create OCI Identity Client")
return nil, err
}
identityClient.SetRegion(region)
identityClt.SetRegion(region)

if c.certOverride != nil {
if client, ok := identityClt.HTTPClient.(*http.Client); ok {
err = c.setCerts(client)
if err != nil {
logger.Error(err, "unable to create OCI Identity Client")
return nil, err
}
} else {
return nil, errors.New("The Identity Client dispatcher is not of http.Client type. Can not patch the tls config.")
}
}

if c.ociClientOverrides != nil && c.ociClientOverrides.IdentityClientUrl != nil {
identityClient.Host = *c.ociClientOverrides.IdentityClientUrl
identityClt.Host = *c.ociClientOverrides.IdentityClientUrl
}
identityClient.Interceptor = setVersionHeader()
identityClt.Interceptor = setVersionHeader()

return &identityClient, nil
return &identityClt, nil
}

func (c *ClientProvider) createComputeClient(region string, ociAuthConfigProvider common.ConfigurationProvider, logger *logr.Logger) (*core.ComputeClient, error) {
Expand All @@ -243,6 +302,19 @@ func (c *ClientProvider) createComputeClient(region string, ociAuthConfigProvide
return nil, err
}
computeClient.SetRegion(region)

if c.certOverride != nil {
if client, ok := computeClient.HTTPClient.(*http.Client); ok {
err = c.setCerts(client)
if err != nil {
logger.Error(err, "unable to create OCI Compute Client")
return nil, err
}
} else {
return nil, errors.New("The Compute Client dispatcher is not of http.Client type. Can not patch the tls config.")
}
}

if c.ociClientOverrides != nil && c.ociClientOverrides.ComputeClientUrl != nil {
computeClient.Host = *c.ociClientOverrides.ComputeClientUrl
}
Expand All @@ -258,6 +330,19 @@ func (c *ClientProvider) createComputeManagementClient(region string, ociAuthCon
return nil, err
}
computeManagementClient.SetRegion(region)

if c.certOverride != nil {
if client, ok := computeManagementClient.HTTPClient.(*http.Client); ok {
err = c.setCerts(client)
if err != nil {
logger.Error(err, "unable to create OCI Compute Management Client")
return nil, err
}
} else {
return nil, errors.New("The Compute Management Client dispatcher is not of http.Client type. Can not patch the tls config.")
}
}

if c.ociClientOverrides != nil && c.ociClientOverrides.ComputeManagementClientUrl != nil {
computeManagementClient.Host = *c.ociClientOverrides.ComputeManagementClientUrl
}
Expand All @@ -267,18 +352,31 @@ func (c *ClientProvider) createComputeManagementClient(region string, ociAuthCon
}

func (c *ClientProvider) createContainerEngineClient(region string, ociAuthConfigProvider common.ConfigurationProvider, logger *logr.Logger) (*containerengine.ContainerEngineClient, error) {
containerEngineClient, err := containerengine.NewContainerEngineClientWithConfigurationProvider(ociAuthConfigProvider)
containerEngineClt, err := containerengine.NewContainerEngineClientWithConfigurationProvider(ociAuthConfigProvider)
if err != nil {
logger.Error(err, "unable to create OCI Container Engine Client")
return nil, err
}
containerEngineClient.SetRegion(region)
containerEngineClt.SetRegion(region)

if c.certOverride != nil {
if client, ok := containerEngineClt.HTTPClient.(*http.Client); ok {
err = c.setCerts(client)
if err != nil {
logger.Error(err, "unable to create OCI Container Engine Client")
return nil, err
}
} else {
return nil, errors.New("The Container Engine Client dispatcher is not of http.Client type. Can not patch the tls config.")
}
}

if c.ociClientOverrides != nil && c.ociClientOverrides.ContainerEngineClientUrl != nil {
containerEngineClient.Host = *c.ociClientOverrides.ContainerEngineClientUrl
containerEngineClt.Host = *c.ociClientOverrides.ContainerEngineClientUrl
}
containerEngineClient.Interceptor = setVersionHeader()
containerEngineClt.Interceptor = setVersionHeader()

return &containerEngineClient, nil
return &containerEngineClt, nil
}

func (c *ClientProvider) createBaseClient(region string, ociAuthConfigProvider common.ConfigurationProvider, logger *logr.Logger) (base.BaseClient, error) {
Expand All @@ -296,3 +394,18 @@ func setVersionHeader() func(request *http.Request) error {
return nil
}
}

// setCerts updates the client TLSClientConfig with the ClientProvider certOverride
func (c *ClientProvider) setCerts(client *http.Client) error {
tr := client.Transport.(*http.Transport).Clone()
if tr.TLSClientConfig != nil {
tr.TLSClientConfig.RootCAs = c.certOverride
} else {
tr.TLSClientConfig = &tls.Config{
RootCAs: c.certOverride,
}
}
client.Transport = tr

return nil
}
16 changes: 7 additions & 9 deletions cloud/scope/clients_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ func TestClients_NewClientProvider(t *testing.T) {
}

clientProvider, err := NewClientProvider(ClientProviderParams{
ociAuthConfigProvider,
nil})
OciAuthConfigProvider: ociAuthConfigProvider})
if err != nil {
t.Errorf("Expected %v to equal nil", err)
}
Expand Down Expand Up @@ -73,8 +72,8 @@ func TestClients_NewClientProviderWithClientOverrides(t *testing.T) {
}

clientProvider, err := NewClientProvider(ClientProviderParams{
ociAuthConfigProvider,
clientOverrides})
OciAuthConfigProvider: ociAuthConfigProvider,
ClientOverrides: clientOverrides})
if err != nil {
t.Errorf("Expected error:%v to not equal nil", err)
}
Expand Down Expand Up @@ -113,8 +112,8 @@ func TestClients_NewClientProviderWithMissingOverrides(t *testing.T) {
}

clientProvider, err := NewClientProvider(ClientProviderParams{
ociAuthConfigProvider,
clientOverrides})
OciAuthConfigProvider: ociAuthConfigProvider,
ClientOverrides: clientOverrides})
if err != nil {
t.Errorf("Expected error:%v to not equal nil", err)
}
Expand All @@ -129,7 +128,7 @@ func TestClients_NewClientProviderWithMissingOverrides(t *testing.T) {
}

func TestClients_NewClientProviderWithBadAuthConfig(t *testing.T) {
clientProvider, err := NewClientProvider(ClientProviderParams{nil, nil})
clientProvider, err := NewClientProvider(ClientProviderParams{})
if err == nil {
t.Errorf("Expected error:%v to not equal nil", err)
}
Expand Down Expand Up @@ -211,8 +210,7 @@ func TestClients_GetAuthProvider(t *testing.T) {
}

clientProvider, err := NewClientProvider(ClientProviderParams{
ociAuthConfigProvider,
nil})
OciAuthConfigProvider: ociAuthConfigProvider})
if err != nil {
t.Errorf("Expected %v to equal nil", err)
}
Expand Down
Loading