Skip to content

Commit

Permalink
Replace NewCredFromCert with NewCredFromCertChain (#391)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Feb 23, 2023
1 parent bddac96 commit 7d2056c
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 60 deletions.
15 changes: 4 additions & 11 deletions apps/confidential/confidential.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ type AuthResult = base.AuthResult

type Account = shared.Account

// CertFromPEM converts a PEM file (.pem or .key) for use with NewCredFromCert(). The file
// CertFromPEM converts a PEM file (.pem or .key) for use with [NewCredFromCert]. The file
// must contain the public certificate and the private key. If a PEM block is encrypted and
// password is not an empty string, it attempts to decrypt the PEM blocks using the password.
// Multiple certs are due to certificate chaining for use cases like TLS that sign from root to leaf.
Expand Down Expand Up @@ -185,16 +185,9 @@ func NewCredFromAssertionCallback(callback func(context.Context, AssertionReques
return Credential{assertionCallback: callback}
}

// NewCredFromCert creates a Credential from an x509.Certificate and an RSA private key.
// CertFromPEM() can be used to get these values from a PEM file.
func NewCredFromCert(cert *x509.Certificate, key crypto.PrivateKey) Credential {
cred, _ := NewCredFromCertChain([]*x509.Certificate{cert}, key)
return cred
}

// NewCredFromCertChain creates a Credential from a chain of x509.Certificates and an RSA private key
// as returned by CertFromPEM().
func NewCredFromCertChain(certs []*x509.Certificate, key crypto.PrivateKey) (Credential, error) {
// NewCredFromCert creates a Credential from a certificate or chain of certificates and an RSA private key
// as returned by [CertFromPEM].
func NewCredFromCert(certs []*x509.Certificate, key crypto.PrivateKey) (Credential, error) {
cred := Credential{key: key}
k, ok := key.(*rsa.PrivateKey)
if !ok {
Expand Down
30 changes: 15 additions & 15 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -366,20 +366,9 @@ func TestAcquireTokenSilentTenants(t *testing.T) {
}

func TestInvalidCredential(t *testing.T) {
data, err := os.ReadFile("../testdata/test-cert.pem")
if err != nil {
t.Fatal(err)
}
certs, key, err := CertFromPEM(data, "")
if err != nil {
t.Fatal(err)
}
for _, cred := range []Credential{
{},
NewCredFromAssertionCallback(nil),
NewCredFromCert(nil, nil),
NewCredFromCert(certs[0], nil),
NewCredFromCert(nil, key),
} {
t.Run("", func(t *testing.T) {
_, err := New(fakeClientID, cred)
Expand All @@ -390,7 +379,7 @@ func TestInvalidCredential(t *testing.T) {
}
}

func TestNewCredFromCertChain(t *testing.T) {
func TestNewCredFromCert(t *testing.T) {
for _, file := range []struct {
path string
numCerts int
Expand Down Expand Up @@ -424,7 +413,7 @@ func TestNewCredFromCertChain(t *testing.T) {
t.Fatal("expected an RSA private key")
}
verifyingKey := &k.PublicKey
cred, err := NewCredFromCertChain(certs, key)
cred, err := NewCredFromCert(certs, key)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -507,7 +496,7 @@ func TestNewCredFromCertChain(t *testing.T) {
}
}

func TestNewCredFromCertChainError(t *testing.T) {
func TestNewCredFromCertError(t *testing.T) {
data, err := os.ReadFile("../testdata/test-cert.pem")
if err != nil {
t.Fatal(err)
Expand All @@ -529,12 +518,23 @@ func TestNewCredFromCertChainError(t *testing.T) {
{[]*x509.Certificate{nil}, key},
} {
t.Run("", func(t *testing.T) {
_, err := NewCredFromCertChain(test.certs, test.key)
_, err := NewCredFromCert(test.certs, test.key)
if err == nil {
t.Fatal("expected an error")
}
})
}

// the key in this file doesn't match the cert loaded above
if data, err = os.ReadFile("../testdata/test-cert-chain.pem"); err != nil {
t.Fatal(err)
}
if _, key, err = CertFromPEM(data, ""); err != nil {
t.Fatal(err)
}
if _, err = NewCredFromCert(certs, key); err == nil {
t.Fatal("expected an error because key doesn't match certs")
}
}

func TestNewCredFromTokenProvider(t *testing.T) {
Expand Down
29 changes: 3 additions & 26 deletions apps/confidential/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,32 +24,9 @@ func ExampleNewCredFromCert_pem() {
log.Fatal(err)
}

// PEM files can have multiple certs. This is usually for certificate chaining where roots
// sign to leafs. Useful for TLS, not for this use case.
if len(certs) > 1 {
log.Fatal("too many certificates in PEM file")
}

cred := confidential.NewCredFromCert(certs[0], priv)
fmt.Println(cred) // Simply here so cred is used, otherwise won't compile.
}

func ExampleNewCredFromCertChain() {
b, err := os.ReadFile("key.pem")
cred, err := confidential.NewCredFromCert(certs, priv)
if err != nil {
// TODO: handle error
}

// CertFromPEM loads certificates and a private key from the PEM content. If
// the content is encrypted, the second argument must be the password.
certs, priv, err := confidential.CertFromPEM(b, "")
if err != nil {
// TODO: handle error
}

cred, err := confidential.NewCredFromCertChain(certs, priv)
if err != nil {
// TODO: handle error
log.Fatal(err)
}
_ = cred
fmt.Println(cred) // Simply here so cred is used, otherwise won't compile.
}
9 changes: 1 addition & 8 deletions apps/tests/devapps/client_certificate_sample.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,7 @@ func acquireTokenClientCertificate() {
if err != nil {
log.Fatal(err)
}

// PEM files can have multiple certs. This is usually for certificate chaining where roots
// sign to leafs. Useful for TLS, not for this use case.
if len(certs) > 1 {
log.Fatal("too many certificates in PEM file")
}

cred := confidential.NewCredFromCert(certs[0], privateKey)
cred, err := confidential.NewCredFromCert(certs, privateKey)
if err != nil {
log.Fatal(err)
}
Expand Down

0 comments on commit 7d2056c

Please sign in to comment.