Skip to content

Commit

Permalink
Allow Azure ClientCertificate authentication
Browse files Browse the repository at this point in the history
Signed-off-by: Hidde Beydals <hello@hidde.co>
  • Loading branch information
hiddeco committed Mar 3, 2022
1 parent 35594b0 commit cb81050
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 14 deletions.
44 changes: 30 additions & 14 deletions pkg/azure/blob.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@ var (
)

const (
resourceIDField = "resourceId"
clientIDField = "clientId"
tenantIDField = "tenantId"
clientSecretField = "clientSecret"
accountKeyField = "accountKey"
resourceIDField = "resourceId"
clientIDField = "clientId"
tenantIDField = "tenantId"
clientSecretField = "clientSecret"
clientCertificateField = "clientCertificate"
clientCertificatePasswordField = "clientCertificatePassword"
accountKeyField = "accountKey"
)

// BlobClient is a minimal Azure Blob client for fetching objects.
Expand All @@ -62,13 +64,17 @@ type BlobClient struct {
// order:
//
// - azidentity.ManagedIdentityCredential for a Resource ID, when a
// resourceIDField is found.
// - azidentity.ManagedIdentityCredential for a User ID, when a clientIDField
// but no tenantIDField found.
// - azidentity.ClientSecretCredential when a tenantIDField, clientIDField and
// clientSecretField are found.
// - azblob.SharedKeyCredential when an accountKeyField is found. The Account
// Name is extracted from the endpoint specified on the Bucket object.
// `resourceId` field is found.
// - azidentity.ManagedIdentityCredential for a User ID, when a `clientId`
// field but no `tenantId` is found.
// - azidentity.ClientCertificateCredential when `tenantId`,
// `clientCertificate` (and optionally `clientCertificatePassword`) fields
// are found.
// - azidentity.ClientSecretCredential when `tenantId`, `clientId` and
// `clientSecret` fields are found.
// - azblob.SharedKeyCredential when an `accountKey` field is found.
// The account name is extracted from the endpoint specified on the Bucket
// object.
//
// If no credentials are found, a simple client without credentials is
// returned.
Expand Down Expand Up @@ -119,6 +125,9 @@ func ValidateSecret(secret *corev1.Secret) error {
if _, hasClientSecret := secret.Data[clientSecretField]; hasClientSecret {
valid = true
}
if _, hasClientCertificate := secret.Data[clientCertificateField]; hasClientCertificate {
valid = true
}
}
}
if _, hasResourceID := secret.Data[resourceIDField]; hasResourceID {
Expand All @@ -132,8 +141,8 @@ func ValidateSecret(secret *corev1.Secret) error {
}

if !valid {
return fmt.Errorf("invalid '%s' secret data: requires a '%s', '%s', or '%s' field, or a combination of '%s', '%s' and '%s'",
secret.Name, resourceIDField, clientIDField, accountKeyField, tenantIDField, clientIDField, clientSecretField)
return fmt.Errorf("invalid '%s' secret data: requires a '%s', '%s', or '%s' field, a combination of '%s', '%s' and '%s', or '%s', '%s' and '%s'",
secret.Name, resourceIDField, clientIDField, accountKeyField, tenantIDField, clientIDField, clientSecretField, tenantIDField, clientIDField, clientCertificateField)
}
return nil
}
Expand Down Expand Up @@ -275,6 +284,13 @@ func tokenCredentialFromSecret(secret *corev1.Secret) (azcore.TokenCredential, e
ID: azidentity.ClientID(clientID),
})
}
if clientCertificate, hasClientCertificate := secret.Data[clientCertificateField]; hasClientCertificate {
certs, key, err := azidentity.ParseCertificates(clientCertificate, secret.Data[clientCertificatePasswordField])
if err != nil {
return nil, fmt.Errorf("failed to parse client certificates: %w", err)
}
return azidentity.NewClientCertificateCredential(string(tenantID), string(clientID), certs, key, nil)
}
if clientSecret, hasClientSecret := secret.Data[clientSecretField]; hasClientSecret {
return azidentity.NewClientSecretCredential(string(tenantID), string(clientID), string(clientSecret), nil)
}
Expand Down
61 changes: 61 additions & 0 deletions pkg/azure/blob_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,14 @@ limitations under the License.
package azure

import (
"bytes"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"math/big"
"testing"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
Expand Down Expand Up @@ -50,6 +56,16 @@ func TestValidateSecret(t *testing.T) {
},
},
},
{
name: "valid ServicePrincipal Certificate Secret",
secret: &corev1.Secret{
Data: map[string][]byte{
tenantIDField: []byte("some-tenant-id-"),
clientIDField: []byte("some-client-id-"),
clientCertificateField: []byte("some-certificate"),
},
},
},
{
name: "valid ServicePrincipal Secret",
secret: &corev1.Secret{
Expand Down Expand Up @@ -192,6 +208,17 @@ func Test_tokenCredentialFromSecret(t *testing.T) {
},
want: &azidentity.ManagedIdentityCredential{},
},
{
name: "with TenantID, ClientID and ClientCertificate fields",
secret: &corev1.Secret{
Data: map[string][]byte{
clientIDField: []byte("client-id"),
tenantIDField: []byte("tenant-id"),
clientCertificateField: validTls(t),
},
},
want: &azidentity.ClientCertificateCredential{},
},
{
name: "with TenantID, ClientID and ClientSecret fields",
secret: &corev1.Secret{
Expand Down Expand Up @@ -316,3 +343,37 @@ func Test_extractAccountNameFromEndpoint1(t *testing.T) {
func endpointURL(accountName string) string {
return fmt.Sprintf("https://%s.blob.core.windows.net", accountName)
}

func validTls(t *testing.T) []byte {
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal("Private key cannot be created.", err.Error())
}

out := bytes.NewBuffer(nil)

var privateKey = &pem.Block{
Type: "PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
}
if err = pem.Encode(out, privateKey); err != nil {
t.Fatal("Private key cannot be PEM encoded.", err.Error())
}

certTemplate := x509.Certificate{
SerialNumber: big.NewInt(1337),
}
cert, err := x509.CreateCertificate(rand.Reader, &certTemplate, &certTemplate, &key.PublicKey, key)
if err != nil {
t.Fatal("Certificate cannot be created.", err.Error())
}
var certificate = &pem.Block{
Type: "CERTIFICATE",
Bytes: cert,
}
if err = pem.Encode(out, certificate); err != nil {
t.Fatal("Certificate cannot be PEM encoded.", err.Error())
}

return out.Bytes()
}

0 comments on commit cb81050

Please sign in to comment.