From b4e9af6fe656ca480a93f148d7980b819ee4ef12 Mon Sep 17 00:00:00 2001 From: Leonard Cohnen Date: Thu, 25 May 2023 11:41:31 +0200 Subject: [PATCH 1/3] add LoadCachedKey function --- client/import.go | 4 ++-- client/keys.go | 24 ++++++++++++++++++------ client/session.go | 30 +++++++++++++++--------------- 3 files changed, 35 insertions(+), 23 deletions(-) diff --git a/client/import.go b/client/import.go index 72f796c3..e88e4409 100644 --- a/client/import.go +++ b/client/import.go @@ -40,7 +40,7 @@ func (k *Key) Import(blob *pb.ImportBlob) ([]byte, error) { } defer tpm2.FlushContext(k.rw, handle) - unsealSession, err := newPCRSession(k.rw, internal.PCRSelection(blob.Pcrs)) + unsealSession, err := NewPCRSession(k.rw, internal.PCRSelection(blob.Pcrs)) if err != nil { return nil, err } @@ -76,7 +76,7 @@ func (k *Key) ImportSigningKey(blob *pb.ImportBlob) (key *Key, err error) { if key.pubArea, _, _, err = tpm2.ReadPublic(k.rw, handle); err != nil { return } - if key.session, err = newPCRSession(k.rw, internal.PCRSelection(blob.Pcrs)); err != nil { + if key.session, err = NewPCRSession(k.rw, internal.PCRSelection(blob.Pcrs)); err != nil { return } return key, key.finish() diff --git a/client/keys.go b/client/keys.go index da148b42..a6ffc279 100644 --- a/client/keys.go +++ b/client/keys.go @@ -27,7 +27,7 @@ type Key struct { pubArea tpm2.Public pubKey crypto.PublicKey name tpm2.Name - session session + session Session cert *x509.Certificate } @@ -114,6 +114,19 @@ func GceAttestationKeyECC(rw io.ReadWriter) (*Key, error) { return akEcc, nil } +// LoadCachedKey loads a key from cachedHandle. +// If the key is not found, an error is returned. +// This function will no overwrite an existing key, unlike NewCachedKey. +func LoadCachedKey(rw io.ReadWriter, cachedHandle tpmutil.Handle, keySession Session) (k *Key, err error) { + cachedPub, _, _, err := tpm2.ReadPublic(rw, cachedHandle) + if err != nil { + return nil, fmt.Errorf("failed to read public area of cached key: %w", err) + } + + k = &Key{rw: rw, handle: cachedHandle, pubArea: cachedPub, session: keySession} + return k, k.finish() +} + // KeyFromNvIndex generates and loads a key under the provided parent // (possibly a hierarchy root tpm2.Handle{Owner|Endorsement|Platform|Null}) // using the template stored at the provided nvdata index. @@ -182,8 +195,7 @@ func NewKey(rw io.ReadWriter, parent tpmutil.Handle, template tpm2.Public) (k *K return nil, fmt.Errorf("unsupported parent handle: %x", parent) } - handle, pubArea, _, _, _, _, err := - tpm2.CreatePrimaryEx(rw, parent, tpm2.PCRSelection{}, "", "", template) + handle, pubArea, _, _, _, _, err := tpm2.CreatePrimaryEx(rw, parent, tpm2.PCRSelection{}, "", "", template) if err != nil { return nil, err } @@ -211,11 +223,11 @@ func (k *Key) finish() error { // We determine the right type of session based on the auth policy if k.session == nil { if bytes.Equal(k.pubArea.AuthPolicy, defaultEKAuthPolicy()) { - if k.session, err = newEKSession(k.rw); err != nil { + if k.session, err = NewEKSession(k.rw); err != nil { return err } } else if len(k.pubArea.AuthPolicy) == 0 { - k.session = nullSession{} + k.session = NullSession{} } else { return fmt.Errorf("unknown auth policy when creating key") } @@ -407,7 +419,7 @@ func (k *Key) Unseal(in *pb.SealedBytes, opts UnsealOpts) ([]byte, error) { sel.PCRs = append(sel.PCRs, int(pcr)) } - session, err := newPCRSession(k.rw, sel) + session, err := NewPCRSession(k.rw, sel) if err != nil { return nil, fmt.Errorf("failed to create session: %w", err) } diff --git a/client/session.go b/client/session.go index 4e496980..5c56eb8f 100644 --- a/client/session.go +++ b/client/session.go @@ -7,7 +7,7 @@ import ( "github.com/google/go-tpm/tpmutil" ) -type session interface { +type Session interface { io.Closer Auth() (tpm2.AuthCommand, error) } @@ -31,42 +31,42 @@ func startAuthSession(rw io.ReadWriter) (session tpmutil.Handle, err error) { return } -type pcrSession struct { +type PCRSession struct { rw io.ReadWriter session tpmutil.Handle sel tpm2.PCRSelection } -func newPCRSession(rw io.ReadWriter, sel tpm2.PCRSelection) (session, error) { +func NewPCRSession(rw io.ReadWriter, sel tpm2.PCRSelection) (Session, error) { if len(sel.PCRs) == 0 { - return nullSession{}, nil + return NullSession{}, nil } session, err := startAuthSession(rw) - return pcrSession{rw, session, sel}, err + return PCRSession{rw, session, sel}, err } -func (p pcrSession) Auth() (auth tpm2.AuthCommand, err error) { +func (p PCRSession) Auth() (auth tpm2.AuthCommand, err error) { if err = tpm2.PolicyPCR(p.rw, p.session, nil, p.sel); err != nil { return } return tpm2.AuthCommand{Session: p.session, Attributes: tpm2.AttrContinueSession}, nil } -func (p pcrSession) Close() error { +func (p PCRSession) Close() error { return tpm2.FlushContext(p.rw, p.session) } -type ekSession struct { +type EKSession struct { rw io.ReadWriter session tpmutil.Handle } -func newEKSession(rw io.ReadWriter) (session, error) { +func NewEKSession(rw io.ReadWriter) (Session, error) { session, err := startAuthSession(rw) - return ekSession{rw, session}, err + return EKSession{rw, session}, err } -func (e ekSession) Auth() (auth tpm2.AuthCommand, err error) { +func (e EKSession) Auth() (auth tpm2.AuthCommand, err error) { nullAuth := tpm2.AuthCommand{Session: tpm2.HandlePasswordSession, Attributes: tpm2.AttrContinueSession} if _, _, err = tpm2.PolicySecret(e.rw, tpm2.HandleEndorsement, nullAuth, e.session, nil, nil, nil, 0); err != nil { return @@ -74,16 +74,16 @@ func (e ekSession) Auth() (auth tpm2.AuthCommand, err error) { return tpm2.AuthCommand{Session: e.session, Attributes: tpm2.AttrContinueSession}, nil } -func (e ekSession) Close() error { +func (e EKSession) Close() error { return tpm2.FlushContext(e.rw, e.session) } -type nullSession struct{} +type NullSession struct{} -func (n nullSession) Auth() (auth tpm2.AuthCommand, err error) { +func (n NullSession) Auth() (auth tpm2.AuthCommand, err error) { return tpm2.AuthCommand{Session: tpm2.HandlePasswordSession, Attributes: tpm2.AttrContinueSession}, nil } -func (n nullSession) Close() error { +func (n NullSession) Close() error { return nil } From c50780f036323a3b00a333a5f4f0b69169884b3c Mon Sep 17 00:00:00 2001 From: Leonard Cohnen Date: Wed, 21 Jun 2023 16:29:53 +0200 Subject: [PATCH 2/3] add TestLoadCachedKey --- client/keys_test.go | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/client/keys_test.go b/client/keys_test.go index b8e41ea7..da1a7342 100644 --- a/client/keys_test.go +++ b/client/keys_test.go @@ -227,6 +227,7 @@ func getTestCert(t *testing.T, pubKey crypto.PublicKey, parentCert *x509.Certifi return cert, certKey } + func TestSetCert(t *testing.T) { rwc := test.GetTPM(t) defer client.CheckedClose(t, rwc) @@ -261,3 +262,41 @@ func TestSetCertFailsIfCertificateIsNotForKey(t *testing.T) { t.Error("SetCert() returned successfully, expected error") } } + +func TestLoadCachedKey(t *testing.T) { + rwc := test.GetTPM(t) + defer client.CheckedClose(t, rwc) + + createdKey, err := client.NewKey(rwc, tpm2.HandleNull, client.SRKTemplateRSA()) + if err != nil { + t.Fatalf("NewKey() returned error: %v", err) + } + defer createdKey.Close() + + handles := []struct { + name string + handle tpmutil.Handle + errExpected bool + }{ + {"successful retrieval with handle", createdKey.Handle(), false}, + {"error for bad handle", tpmutil.Handle(0x0), true}, + } + + for _, k := range handles { + t.Run(k.name, func(t *testing.T) { + loadedKey, err := client.LoadCachedKey(rwc, createdKey.Handle(), client.NullSession{}) + if k.errExpected && err == nil { + t.Fatal("LoadCachedKey() returned successfully, expected error") + } else if !k.errExpected && err != nil { + t.Fatalf("LoadCachedKey() returned error: %v", err) + } else if k.errExpected { + return + } + defer loadedKey.Close() + + if !reflect.DeepEqual(createdKey, loadedKey) { + t.Errorf("Loaded key does not match created key") + } + }) + } +} From c4a1ba7ab98b2744d82fc125a671075f13facbc9 Mon Sep 17 00:00:00 2001 From: Leonard Cohnen Date: Thu, 22 Jun 2023 15:54:34 +0200 Subject: [PATCH 3/3] fix linter issues --- client/keys.go | 2 +- client/session.go | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/client/keys.go b/client/keys.go index a6ffc279..59e9aa19 100644 --- a/client/keys.go +++ b/client/keys.go @@ -116,7 +116,7 @@ func GceAttestationKeyECC(rw io.ReadWriter) (*Key, error) { // LoadCachedKey loads a key from cachedHandle. // If the key is not found, an error is returned. -// This function will no overwrite an existing key, unlike NewCachedKey. +// This function will not overwrite an existing key, unlike NewCachedKey. func LoadCachedKey(rw io.ReadWriter, cachedHandle tpmutil.Handle, keySession Session) (k *Key, err error) { cachedPub, _, _, err := tpm2.ReadPublic(rw, cachedHandle) if err != nil { diff --git a/client/session.go b/client/session.go index 5c56eb8f..bf3a1d11 100644 --- a/client/session.go +++ b/client/session.go @@ -7,6 +7,7 @@ import ( "github.com/google/go-tpm/tpmutil" ) +// Session is an interface for TPM sessions. type Session interface { io.Closer Auth() (tpm2.AuthCommand, error) @@ -31,12 +32,14 @@ func startAuthSession(rw io.ReadWriter) (session tpmutil.Handle, err error) { return } +// PCRSession is a TPM session that is bound to a set of PCRs. type PCRSession struct { rw io.ReadWriter session tpmutil.Handle sel tpm2.PCRSelection } +// NewPCRSession creates a new PCRSession. func NewPCRSession(rw io.ReadWriter, sel tpm2.PCRSelection) (Session, error) { if len(sel.PCRs) == 0 { return NullSession{}, nil @@ -45,6 +48,7 @@ func NewPCRSession(rw io.ReadWriter, sel tpm2.PCRSelection) (Session, error) { return PCRSession{rw, session, sel}, err } +// Auth returns the AuthCommand for the session. func (p PCRSession) Auth() (auth tpm2.AuthCommand, err error) { if err = tpm2.PolicyPCR(p.rw, p.session, nil, p.sel); err != nil { return @@ -52,20 +56,24 @@ func (p PCRSession) Auth() (auth tpm2.AuthCommand, err error) { return tpm2.AuthCommand{Session: p.session, Attributes: tpm2.AttrContinueSession}, nil } +// Close closes the session. func (p PCRSession) Close() error { return tpm2.FlushContext(p.rw, p.session) } +// EKSession is a TPM session that is bound to the EK. type EKSession struct { rw io.ReadWriter session tpmutil.Handle } +// NewEKSession creates a new EKSession. func NewEKSession(rw io.ReadWriter) (Session, error) { session, err := startAuthSession(rw) return EKSession{rw, session}, err } +// Auth returns the AuthCommand for the session. func (e EKSession) Auth() (auth tpm2.AuthCommand, err error) { nullAuth := tpm2.AuthCommand{Session: tpm2.HandlePasswordSession, Attributes: tpm2.AttrContinueSession} if _, _, err = tpm2.PolicySecret(e.rw, tpm2.HandleEndorsement, nullAuth, e.session, nil, nil, nil, 0); err != nil { @@ -74,16 +82,20 @@ func (e EKSession) Auth() (auth tpm2.AuthCommand, err error) { return tpm2.AuthCommand{Session: e.session, Attributes: tpm2.AttrContinueSession}, nil } +// Close closes the session. func (e EKSession) Close() error { return tpm2.FlushContext(e.rw, e.session) } +// NullSession is a TPM session that is not bound to anything. type NullSession struct{} +// Auth returns the AuthCommand for the session. func (n NullSession) Auth() (auth tpm2.AuthCommand, err error) { return tpm2.AuthCommand{Session: tpm2.HandlePasswordSession, Attributes: tpm2.AttrContinueSession}, nil } +// Close closes the session. func (n NullSession) Close() error { return nil }