Skip to content

Commit

Permalink
Merge pull request #204 from smallstep/permanent-identifier
Browse files Browse the repository at this point in the history
Add permanent identifier in the CreateAttestation response
  • Loading branch information
maraino authored Mar 25, 2023
2 parents 0ac014f + 797be4d commit 54a1c86
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 67 deletions.
2 changes: 1 addition & 1 deletion internal/bcrypt_pbkdf/bcrypt_pbkdf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,6 @@ func BenchmarkKey(b *testing.B) {
pass := []byte("password")
salt := []byte("salt")
for i := 0; i < b.N; i++ {
Key(pass, salt, 10, 32)
_, _ = Key(pass, salt, 10, 32)
}
}
7 changes: 4 additions & 3 deletions kms/apiv1/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ type CreateAttestationRequest struct {
// Notice: This API is EXPERIMENTAL and may be changed or removed in a later
// release.
type CreateAttestationResponse struct {
Certificate *x509.Certificate
CertificateChain []*x509.Certificate
PublicKey crypto.PublicKey
Certificate *x509.Certificate
CertificateChain []*x509.Certificate
PublicKey crypto.PublicKey
PermanentIdentifier string
}
2 changes: 1 addition & 1 deletion kms/pkcs11/benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ func benchmarkSign(b *testing.B, signer crypto.Signer, opts crypto.SignerOpts) {
digest := h.Sum(nil)
b.ResetTimer()
for i := 0; i < b.N; i++ {
signer.Sign(rand.Reader, digest, opts)
_, _ = signer.Sign(rand.Reader, digest, opts)
}
b.StopTimer()
}
Expand Down
8 changes: 6 additions & 2 deletions kms/pkcs11/pkcs11.go
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,9 @@ func (k *PKCS11) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
return errors.Wrap(err, "storeCertificate failed")
}
if req.Extractable {
template.Set(crypto11.CkaExtractable, true)
if err := template.Set(crypto11.CkaExtractable, true); err != nil {
return errors.Wrap(err, "storeCertificate failed")
}
}
if err := k.p11.ImportCertificateWithAttributes(template, req.Certificate); err != nil {
return errors.Wrap(err, "storeCertificate failed")
Expand Down Expand Up @@ -326,7 +328,9 @@ func generateKey(ctx P11, req *apiv1.CreateKeyRequest) (crypto11.Signer, error)
}
private := public.Copy()
if req.Extractable {
private.Set(crypto11.CkaExtractable, true)
if err := private.Set(crypto11.CkaExtractable, true); err != nil {
return nil, err
}
}

bits := req.Bits
Expand Down
10 changes: 5 additions & 5 deletions kms/pkcs11/pkcs11_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func TestPKCS11_CreateKey(t *testing.T) {
k := setupPKCS11(t)

// Make sure to delete the created key
k.DeleteKey(testObject)
_ = k.DeleteKey(testObject)

type args struct {
req *apiv1.CreateKeyRequest
Expand Down Expand Up @@ -553,13 +553,13 @@ func TestPKCS11_CreateDecrypter(t *testing.T) {
}

// RSA-OAEP
enc, err = rsa.EncryptOAEP(crypto.SHA256.New(), rand.Reader, pub, data, []byte("label"))
enc, err = rsa.EncryptOAEP(crypto.SHA1.New(), rand.Reader, pub, data, []byte("label"))
if err != nil {
t.Errorf("rsa.EncryptOAEP() error = %v", err)
return
}
dec, err = got.Decrypt(rand.Reader, enc, &rsa.OAEPOptions{
Hash: crypto.SHA256,
Hash: crypto.SHA1,
Label: []byte("label"),
})
if err != nil {
Expand Down Expand Up @@ -653,8 +653,8 @@ func TestPKCS11_StoreCertificate(t *testing.T) {

// Make sure to delete the created certificate
t.Cleanup(func() {
k.DeleteCertificate(testObject)
k.DeleteCertificate(testObjectAlt)
_ = k.DeleteCertificate(testObject)
_ = k.DeleteCertificate(testObjectAlt)
})

type args struct {
Expand Down
2 changes: 1 addition & 1 deletion kms/sshagentkms/sshagentkms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func startOpenSSHAgent(t *testing.T) (client agent.Agent, socket string, cleanup
return ac, socket, func() {
proc, _ := os.FindProcess(pid)
if proc != nil {
proc.Kill()
_ = proc.Kill()
}
conn.Close()
os.RemoveAll(filepath.Dir(socket))
Expand Down
32 changes: 16 additions & 16 deletions kms/yubikey/yubikey.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,11 @@ func New(ctx context.Context, opts apiv1.Options) (*YubiKey, error) {
// Attempt to locate the yubikey with the given serial.
for _, name := range cards {
if k, err := pivOpen(name); err == nil {
if serialNumber, err := getSerialNumber(k); err == nil && serial == serialNumber {
yk = k
break
if cert, err := k.Attest(piv.SlotAuthentication); err == nil {
if serial == getSerialNumber(cert) {
yk = k
break
}
}
}
}
Expand Down Expand Up @@ -321,9 +323,10 @@ func (k *YubiKey) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1
}

return &apiv1.CreateAttestationResponse{
Certificate: cert,
CertificateChain: []*x509.Certificate{intermediate},
PublicKey: cert.PublicKey,
Certificate: cert,
CertificateChain: []*x509.Certificate{intermediate},
PublicKey: cert.PublicKey,
PermanentIdentifier: getSerialNumber(cert),
}, nil
}

Expand Down Expand Up @@ -471,22 +474,19 @@ func getPolicies(req *apiv1.CreateKeyRequest) (piv.PINPolicy, piv.TouchPolicy) {
return pin, touch
}

// getSerialNumber gets an attestation certificate on the given key and returns
// the serial number on it.
func getSerialNumber(yk pivKey) (string, error) {
cert, err := yk.Attest(piv.SlotAuthentication)
if err != nil {
return "", err
}
// getSerialNumber returns the serial number from an attestation certificate. It
// will return an empty string if the serial number extension does not exist
// or if it is malformed.
func getSerialNumber(cert *x509.Certificate) string {
for _, ext := range cert.Extensions {
if ext.Id.Equal(oidYubicoSerialNumber) {
var serialNumber int
rest, err := asn1.Unmarshal(ext.Value, &serialNumber)
if err != nil || len(rest) > 0 {
return "", errors.New("error parsing YubiKey serial number")
return ""
}
return strconv.Itoa(serialNumber), nil
return strconv.Itoa(serialNumber)
}
}
return "", errors.New("failed to find YubiKey serial number")
return ""
}
133 changes: 98 additions & 35 deletions kms/yubikey/yubikey_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,25 @@ func (s *stubPivKey) Close() error {
return nil
}

func TestRegister(t *testing.T) {
pCards := pivCards
t.Cleanup(func() {
pivCards = pCards
})

pivCards = func() ([]string, error) {
return []string{"Yubico YubiKey OTP+FIDO+CCID"}, nil
}

fn, ok := apiv1.LoadKeyManagerNewFunc(apiv1.YubiKey)
if !ok {
t.Fatal("YubiKey is not registered")
}
_, _ = fn(context.Background(), apiv1.Options{
Type: "YubiKey", URI: "yubikey:",
})
}

func TestNew(t *testing.T) {
ctx := context.Background()
pOpen := pivOpen
Expand Down Expand Up @@ -956,9 +975,10 @@ func TestYubiKey_CreateAttestation(t *testing.T) {
{"ok", fields{yk, "123456", piv.DefaultManagementKey}, args{&apiv1.CreateAttestationRequest{
Name: "yubikey:slot-id=9a",
}}, &apiv1.CreateAttestationResponse{
Certificate: yk.attestMap[piv.SlotAuthentication],
CertificateChain: []*x509.Certificate{yk.attestCA.Intermediate},
PublicKey: yk.attestMap[piv.SlotAuthentication].PublicKey,
Certificate: yk.attestMap[piv.SlotAuthentication],
CertificateChain: []*x509.Certificate{yk.attestCA.Intermediate},
PublicKey: yk.attestMap[piv.SlotAuthentication].PublicKey,
PermanentIdentifier: "112233",
}, false},
{"fail getSlot", fields{yk, "123456", piv.DefaultManagementKey}, args{&apiv1.CreateAttestationRequest{
Name: "yubikey://:slot-id=9a",
Expand Down Expand Up @@ -1019,61 +1039,104 @@ func TestYubiKey_Close(t *testing.T) {
}

func Test_getSerialNumber(t *testing.T) {
ok := newStubPivKey(t, RSA)

failAttest := newStubPivKey(t, RSA)
delete(failAttest.attestMap, piv.SlotAuthentication)

failParse := newStubPivKey(t, ECDSA)
serialNumer, err := asn1.Marshal("112233")
serialNumber, err := asn1.Marshal(112233)
if err != nil {
t.Fatal(err)
}
attCertParse, err := failParse.attestCA.Sign(&x509.Certificate{
Subject: pkix.Name{CommonName: "attested certificate"},
PublicKey: failParse.attestSigner.Public(),
ExtraExtensions: []pkix.Extension{
{Id: oidYubicoSerialNumber, Value: serialNumer},
},
})
printableSerialNumber, err := asn1.Marshal("112233")
if err != nil {
t.Fatal(err)
}
failMissing := newStubPivKey(t, ECDSA)
attCertMissing, err := failMissing.attestCA.Sign(&x509.Certificate{

yk := newStubPivKey(t, RSA)
okCert := yk.attestMap[piv.SlotAuthentication]
printableCert := &x509.Certificate{
Subject: pkix.Name{CommonName: "attested certificate"},
PublicKey: failMissing.attestSigner.Public(),
})
if err != nil {
t.Fatal(err)
PublicKey: okCert.PublicKey,
Extensions: []pkix.Extension{
{Id: oidYubicoSerialNumber, Value: printableSerialNumber},
},
}
restCert := &x509.Certificate{
Subject: pkix.Name{CommonName: "attested certificate"},
PublicKey: okCert.PublicKey,
Extensions: []pkix.Extension{
{Id: oidYubicoSerialNumber, Value: append(serialNumber, 0)},
},
}
missingCert := &x509.Certificate{
Subject: pkix.Name{CommonName: "attested certificate"},
PublicKey: okCert.PublicKey,
}

type args struct {
cert *x509.Certificate
}
tests := []struct {
name string
args args
want string
}{
{"ok", args{okCert}, "112233"},
{"fail printable", args{printableCert}, ""},
{"fail rest", args{restCert}, ""},
{"fail missing", args{missingCert}, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getSerialNumber(tt.args.cert); got != tt.want {
t.Errorf("getSerialNumber() = %v, want %v", got, tt.want)
}
})
}
}

failParse.attestMap[piv.SlotAuthentication] = attCertParse
failMissing.attestMap[piv.SlotAuthentication] = attCertMissing
func Test_getSignatureAlgorithm(t *testing.T) {
fake := apiv1.SignatureAlgorithm(1000)
t.Cleanup(func() {
delete(signatureAlgorithmMapping, fake)
})
signatureAlgorithmMapping[fake] = "fake"

type args struct {
yk pivKey
alg apiv1.SignatureAlgorithm
bits int
}
tests := []struct {
name string
args args
want string
want piv.Algorithm
wantErr bool
}{
{"ok", args{ok}, "112233", false},
{"fail attest", args{failAttest}, "", true},
{"fail parse", args{failParse}, "", true},
{"fail missing", args{failMissing}, "", true},
{"default", args{apiv1.UnspecifiedSignAlgorithm, 0}, piv.AlgorithmEC256, false},
{"SHA256WithRSA", args{apiv1.SHA256WithRSA, 0}, piv.AlgorithmRSA2048, false},
{"SHA512WithRSA", args{apiv1.SHA512WithRSA, 0}, piv.AlgorithmRSA2048, false},
{"SHA256WithRSAPSS", args{apiv1.SHA256WithRSAPSS, 0}, piv.AlgorithmRSA2048, false},
{"SHA512WithRSAPSS", args{apiv1.SHA512WithRSAPSS, 0}, piv.AlgorithmRSA2048, false},
{"ECDSAWithSHA256", args{apiv1.ECDSAWithSHA256, 0}, piv.AlgorithmEC256, false},
{"ECDSAWithSHA384", args{apiv1.ECDSAWithSHA384, 0}, piv.AlgorithmEC384, false},
{"PureEd25519", args{apiv1.PureEd25519, 0}, piv.AlgorithmEd25519, false},
{"SHA256WithRSA 2048", args{apiv1.SHA256WithRSA, 2048}, piv.AlgorithmRSA2048, false},
{"SHA512WithRSA 2048", args{apiv1.SHA512WithRSA, 2048}, piv.AlgorithmRSA2048, false},
{"SHA256WithRSAPSS 2048", args{apiv1.SHA256WithRSAPSS, 2048}, piv.AlgorithmRSA2048, false},
{"SHA512WithRSAPSS 2048", args{apiv1.SHA512WithRSAPSS, 2048}, piv.AlgorithmRSA2048, false},
{"SHA256WithRSA 1024", args{apiv1.SHA256WithRSA, 1024}, piv.AlgorithmRSA1024, false},
{"SHA512WithRSA 1024", args{apiv1.SHA512WithRSA, 1024}, piv.AlgorithmRSA1024, false},
{"SHA256WithRSAPSS 1024", args{apiv1.SHA256WithRSAPSS, 1024}, piv.AlgorithmRSA1024, false},
{"SHA512WithRSAPSS 1024", args{apiv1.SHA512WithRSAPSS, 1024}, piv.AlgorithmRSA1024, false},
{"fail 4096", args{apiv1.SHA256WithRSA, 4096}, 0, true},
{"fail unknown", args{apiv1.SignatureAlgorithm(100), 0}, 0, true},
{"fail default case", args{apiv1.SignatureAlgorithm(1000), 0}, 0, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := getSerialNumber(tt.args.yk)
got, err := getSignatureAlgorithm(tt.args.alg, tt.args.bits)
if (err != nil) != tt.wantErr {
t.Errorf("getSerialNumber() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("getSignatureAlgorithm() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("getSerialNumber() = %v, want %v", got, tt.want)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("getSignatureAlgorithm() = %v, want %v", got, tt.want)
}
})
}
Expand Down
10 changes: 7 additions & 3 deletions x509util/extensions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,9 @@ func TestKeyUsage_MarshalJSON(t *testing.T) {
t.Errorf("KeyUsage.MarshalJSON() = %q, want %q", string(got), tt.want)
}
var unmarshaled KeyUsage
unmarshaled.UnmarshalJSON(got)
if err := unmarshaled.UnmarshalJSON(got); err != nil {
t.Errorf("KeyUsage.UnmarshalJSON() error = %v", err)
}
if unmarshaled != tt.k {
t.Errorf("KeyUsage.UnmarshalJSON(keyUsage.MarshalJSON) = %v, want %v", unmarshaled, tt.k)
}
Expand Down Expand Up @@ -582,7 +584,7 @@ func TestExtKeyUsage_MarshalJSON(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.eku.MarshalJSON()
if (err != nil) != tt.wantErr {
t.Fatalf("ExtKeyUsage.MarshalJSON() = error = %v, wantErr %v", err, tt.wantErr)
t.Fatalf("ExtKeyUsage.MarshalJSON() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
return
Expand All @@ -591,7 +593,9 @@ func TestExtKeyUsage_MarshalJSON(t *testing.T) {
t.Errorf("ExtKeyUsage.MarshalJSON() = %q, want %q", string(got), tt.want)
}
var unmarshaled ExtKeyUsage
unmarshaled.UnmarshalJSON(got)
if err := unmarshaled.UnmarshalJSON(got); err != nil {
t.Errorf("ExtKeyUsage.UnmarshalJSON() error = %v", err)
}
if !reflect.DeepEqual(unmarshaled, tt.eku) {
t.Errorf("ExtKeyUsage.UnmarshalJSON(ExtKeyUsage.MarshalJSON) = %v, want %v", unmarshaled, tt.eku)
}
Expand Down

0 comments on commit 54a1c86

Please sign in to comment.