diff --git a/bccsp/pkcs11/ecdsa.go b/bccsp/pkcs11/ecdsa.go index a7166725e26..cd882434f0d 100644 --- a/bccsp/pkcs11/ecdsa.go +++ b/bccsp/pkcs11/ecdsa.go @@ -46,6 +46,6 @@ func (csp *impl) verifyECDSA(k ecdsaPublicKey, signature, digest []byte, opts bc if csp.softVerify { return ecdsa.Verify(k.pub, digest, r, s), nil } - return csp.verifyP11ECDSA(k.ski, digest, r, s, k.pub.Curve.Params().BitSize/8) + return csp.verifyP11ECDSA(k.ski, digest, r, s, k.pub.Curve.Params().BitSize/8) } diff --git a/bccsp/pkcs11/pkcs11.go b/bccsp/pkcs11/pkcs11.go index 3d1e850d867..9107cf270b3 100644 --- a/bccsp/pkcs11/pkcs11.go +++ b/bccsp/pkcs11/pkcs11.go @@ -27,6 +27,8 @@ import ( "go.uber.org/zap/zapcore" ) +const createSessionRetries = 10 + var ( logger = flogging.MustGetLogger("bccsp_p11") sessionCacheSize = 10 @@ -35,14 +37,19 @@ var ( type impl struct { bccsp.BCCSP - slot uint - pin string - ctx *pkcs11.Ctx - sessions chan pkcs11.SessionHandle - + slot uint + pin string + ctx *pkcs11.Ctx conf *config softVerify bool immutable bool + + sessLock sync.Mutex + sessPool chan pkcs11.SessionHandle + sessions map[pkcs11.SessionHandle]struct{} + + cacheLock sync.RWMutex + handleCache map[string]pkcs11.ObjectHandle } // New WithParams returns a new instance of the software-based BCCSP @@ -60,12 +67,19 @@ func New(opts PKCS11Opts, keyStore bccsp.KeyStore) (bccsp.BCCSP, error) { return nil, errors.Wrapf(err, "Failed initializing fallback SW BCCSP") } + var sessPool chan pkcs11.SessionHandle + if sessionCacheSize > 0 { + sessPool = make(chan pkcs11.SessionHandle, sessionCacheSize) + } + csp := &impl{ - BCCSP: swCSP, - conf: conf, - sessions: make(chan pkcs11.SessionHandle, sessionCacheSize), - softVerify: opts.SoftVerify, - immutable: opts.Immutable, + BCCSP: swCSP, + conf: conf, + sessPool: sessPool, + sessions: map[pkcs11.SessionHandle]struct{}{}, + handleCache: map[string]pkcs11.ObjectHandle{}, + softVerify: opts.SoftVerify, + immutable: opts.Immutable, } return csp.initialize(opts) @@ -80,8 +94,13 @@ func (csp *impl) initialize(opts PKCS11Opts) (*impl, error) { if ctx == nil { return nil, fmt.Errorf("pkcs11: instantiation failed for %s", opts.Library) } + if err := ctx.Initialize(); err != nil { + logger.Debugf("initialize failed: %v", err) + } + + csp.ctx = ctx + csp.pin = opts.Pin - ctx.Initialize() slots, err := ctx.GetSlotList(true) if err != nil { return nil, errors.Wrap(err, "pkcs11: get slot list") @@ -93,11 +112,9 @@ func (csp *impl) initialize(opts PKCS11Opts) (*impl, error) { continue } - csp.ctx = ctx csp.slot = s - csp.pin = opts.Pin - session, err := createSession(ctx, s, opts.Pin) + session, err := csp.createSession() if err != nil { return nil, err } @@ -261,30 +278,31 @@ func (csp *impl) Decrypt(k bccsp.Key, ciphertext []byte, opts bccsp.DecrypterOpt func (csp *impl) getSession() (session pkcs11.SessionHandle, err error) { for { select { - case session = <-csp.sessions: + case session = <-csp.sessPool: if _, err = csp.ctx.GetSessionInfo(session); err == nil { logger.Debugf("Reusing existing pkcs11 session %d on slot %d\n", session, csp.slot) return session, nil } logger.Warningf("Get session info failed [%s], closing existing session and getting a new session\n", err) - csp.ctx.CloseSession(session) + csp.closeSession(session) default: // cache is empty (or completely in use), create a new session - return createSession(csp.ctx, csp.slot, csp.pin) + return csp.createSession() } } } -func createSession(ctx *pkcs11.Ctx, slot uint, pin string) (pkcs11.SessionHandle, error) { +func (csp *impl) createSession() (pkcs11.SessionHandle, error) { var sess pkcs11.SessionHandle var err error - // attempt 10 times to open a session with a 100ms delay after each attempt - for i := 0; i < 10; i++ { - sess, err = ctx.OpenSession(slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) + + // attempt to open a session with a 100ms delay after each attempt + for i := 0; i < createSessionRetries; i++ { + sess, err = csp.ctx.OpenSession(csp.slot, pkcs11.CKF_SERIAL_SESSION|pkcs11.CKF_RW_SESSION) if err == nil { - logger.Debugf("Created new pkcs11 session %d on slot %d\n", sess, slot) + logger.Debugf("Created new pkcs11 session %d on slot %d\n", sess, csp.slot) break } @@ -295,20 +313,41 @@ func createSession(ctx *pkcs11.Ctx, slot uint, pin string) (pkcs11.SessionHandle return 0, errors.Wrap(err, "OpenSession failed") } - err = ctx.Login(sess, pkcs11.CKU_USER, pin) + err = csp.ctx.Login(sess, pkcs11.CKU_USER, csp.pin) if err != nil && err != pkcs11.Error(pkcs11.CKR_USER_ALREADY_LOGGED_IN) { - return sess, errors.Wrap(err, "Login failed") + csp.ctx.CloseSession(sess) + return 0, errors.Wrap(err, "Login failed") } + + csp.sessLock.Lock() + csp.sessions[sess] = struct{}{} + csp.sessLock.Unlock() + return sess, nil } +func (csp *impl) closeSession(session pkcs11.SessionHandle) { + if err := csp.ctx.CloseSession(session); err != nil { + logger.Debug("CloseSession failed", err) + } + + csp.sessLock.Lock() + defer csp.sessLock.Unlock() + + // purge the handle cache if the last session closes + delete(csp.sessions, session) + if len(csp.sessions) == 0 { + csp.clearHandleCache() + } +} + func (csp *impl) returnSession(session pkcs11.SessionHandle) { select { - case csp.sessions <- session: + case csp.sessPool <- session: // returned session back to session cache default: // have plenty of sessions in cache, dropping - csp.ctx.CloseSession(session) + csp.closeSession(session) } } @@ -331,7 +370,7 @@ func (csp *impl) getECKey(ski []byte) (pubKey *ecdsa.PublicKey, isPriv bool, err return nil, false, fmt.Errorf("Public key not found [%s] for SKI [%s]", err, hex.EncodeToString(ski)) } - ecpt, marshaledOid, err := csp.ecPoint(session, *publicKey) + ecpt, marshaledOid, err := csp.ecPoint(session, publicKey) if err != nil { return nil, false, fmt.Errorf("Public key not found [%s] for SKI [%s]", err, hex.EncodeToString(ski)) } @@ -522,7 +561,7 @@ func (csp *impl) signP11ECDSA(ski []byte, msg []byte) (R, S *big.Int, err error) return nil, nil, fmt.Errorf("Private key not found [%s]", err) } - err = csp.ctx.SignInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_ECDSA, nil)}, *privateKey) + err = csp.ctx.SignInit(session, []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_ECDSA, nil)}, privateKey) if err != nil { return nil, nil, fmt.Errorf("Sign-initialize failed [%s]", err) } @@ -567,7 +606,7 @@ func (csp *impl) verifyP11ECDSA(ski []byte, msg []byte, R, S *big.Int, byteSize err = csp.ctx.VerifyInit( session, []*pkcs11.Mechanism{pkcs11.NewMechanism(pkcs11.CKM_ECDSA, nil)}, - *publicKey, + publicKey, ) if err != nil { return false, fmt.Errorf("PKCS11: Verify-initialize [%s]", err) @@ -590,7 +629,35 @@ const ( privateKeyType ) -func (csp *impl) findKeyPairFromSKI(session pkcs11.SessionHandle, ski []byte, keyType keyType) (*pkcs11.ObjectHandle, error) { +func (csp *impl) cachedHandle(keyType keyType, ski []byte) (pkcs11.ObjectHandle, bool) { + cacheKey := hex.EncodeToString(append([]byte{byte(keyType)}, ski...)) + csp.cacheLock.RLock() + defer csp.cacheLock.RUnlock() + + handle, ok := csp.handleCache[cacheKey] + return handle, ok +} + +func (csp *impl) cacheHandle(keyType keyType, ski []byte, handle pkcs11.ObjectHandle) { + cacheKey := hex.EncodeToString(append([]byte{byte(keyType)}, ski...)) + csp.cacheLock.Lock() + defer csp.cacheLock.Unlock() + + csp.handleCache[cacheKey] = handle +} + +func (csp *impl) clearHandleCache() { + csp.cacheLock.Lock() + defer csp.cacheLock.Unlock() + csp.handleCache = map[string]pkcs11.ObjectHandle{} +} + +func (csp *impl) findKeyPairFromSKI(session pkcs11.SessionHandle, ski []byte, keyType keyType) (pkcs11.ObjectHandle, error) { + // check for cached handle + if handle, ok := csp.cachedHandle(keyType, ski); ok { + return handle, nil + } + ktype := pkcs11.CKO_PUBLIC_KEY if keyType == privateKeyType { ktype = pkcs11.CKO_PRIVATE_KEY @@ -601,23 +668,25 @@ func (csp *impl) findKeyPairFromSKI(session pkcs11.SessionHandle, ski []byte, ke pkcs11.NewAttribute(pkcs11.CKA_ID, ski), } if err := csp.ctx.FindObjectsInit(session, template); err != nil { - return nil, err + return 0, err } // single session instance, assume one hit only objs, _, err := csp.ctx.FindObjects(session, 1) if err != nil { - return nil, err + return 0, err } if err = csp.ctx.FindObjectsFinal(session); err != nil { - return nil, err + return 0, err } - if len(objs) == 0 { - return nil, fmt.Errorf("Key not found [%s]", hex.Dump(ski)) + return 0, fmt.Errorf("Key not found [%s]", hex.Dump(ski)) } - return &objs[0], nil + // cache the found handle + csp.cacheHandle(keyType, ski, objs[0]) + + return objs[0], nil } // Fairly straightforward EC-point query, other than opencryptoki diff --git a/bccsp/pkcs11/pkcs11_test.go b/bccsp/pkcs11/pkcs11_test.go index bd088c6ace2..a2db1b3539f 100644 --- a/bccsp/pkcs11/pkcs11_test.go +++ b/bccsp/pkcs11/pkcs11_test.go @@ -1353,6 +1353,144 @@ func TestPKCS11GetSession(t *testing.T) { } } +func TestCaching(t *testing.T) { + defer func(s int) { sessionCacheSize = s }(sessionCacheSize) + opts := PKCS11Opts{ + HashFamily: "SHA2", + SecLevel: 256, + SoftVerify: false, + } + opts.Library, opts.Pin, opts.Label = FindPKCS11Lib() + + verifyHandleCache := func(t *testing.T, pi *impl, sess pkcs11.SessionHandle, k bccsp.Key) { + pubHandle, err := pi.findKeyPairFromSKI(sess, k.SKI(), publicKeyType) + require.NoError(t, err) + h, ok := pi.cachedHandle(publicKeyType, k.SKI()) + require.True(t, ok) + require.Equal(t, h, pubHandle) + + privHandle, err := pi.findKeyPairFromSKI(sess, k.SKI(), privateKeyType) + require.NoError(t, err) + h, ok = pi.cachedHandle(privateKeyType, k.SKI()) + require.True(t, ok) + require.Equal(t, h, privHandle) + } + + t.Run("SessionCacheDisabled", func(t *testing.T) { + sessionCacheSize = 0 + + provider, err := New(opts, currentKS) + require.NoError(t, err) + pi := provider.(*impl) + defer pi.ctx.Destroy() + + require.Nil(t, pi.sessPool, "sessPool channel should be nil") + require.Empty(t, pi.sessions, "sessions set should be empty") + require.Empty(t, pi.handleCache, "handleCache should be empty") + + sess1, err := pi.getSession() + require.NoError(t, err) + require.Len(t, pi.sessions, 1, "expected one open session") + + sess2, err := pi.getSession() + require.NoError(t, err) + require.Len(t, pi.sessions, 2, "expected two open sessions") + + // Generate a key + k, err := pi.KeyGen(&bccsp.ECDSAP256KeyGenOpts{Temporary: false}) + require.NoError(t, err) + verifyHandleCache(t, pi, sess1, k) + require.Len(t, pi.handleCache, 2, "expected two handles in handle cache") + + pi.returnSession(sess1) + require.Len(t, pi.sessions, 1, "expected one open session") + verifyHandleCache(t, pi, sess1, k) + require.Len(t, pi.handleCache, 2, "expected two handles in handle cache") + + pi.returnSession(sess2) + require.Empty(t, pi.sessions, "expected sessions to be empty") + require.Empty(t, pi.handleCache, "expected handles to be cleared") + + pi.slot = ^uint(0) // break OpenSession + _, err = pi.getSession() + require.EqualError(t, err, "OpenSession failed: pkcs11: 0x3: CKR_SLOT_ID_INVALID") + require.Empty(t, pi.sessions, "expected sessions to be empty") + }) + + t.Run("SessionCacheEnabled", func(t *testing.T) { + sessionCacheSize = 1 + + provider, err := New(opts, currentKS) + require.NoError(t, err) + pi := provider.(*impl) + defer pi.ctx.Destroy() + + require.NotNil(t, pi.sessPool, "sessPool channel should not be nil") + require.Equal(t, 1, cap(pi.sessPool)) + require.Len(t, pi.sessions, 1, "sessions should contain login session") + require.Len(t, pi.sessPool, 1, "sessionPool should hold login session") + require.Empty(t, pi.handleCache, "handleCache should be empty") + + sess1, err := pi.getSession() + require.NoError(t, err) + require.Len(t, pi.sessions, 1, "expected one open session (sess1 from login)") + require.Len(t, pi.sessPool, 0, "sessionPool should be empty") + + sess2, err := pi.getSession() + require.NoError(t, err) + require.Len(t, pi.sessions, 2, "expected two open sessions (sess1 and sess2)") + require.Len(t, pi.sessPool, 0, "sessionPool should be empty") + + // Generate a key + k, err := pi.KeyGen(&bccsp.ECDSAP256KeyGenOpts{Temporary: false}) + require.NoError(t, err) + verifyHandleCache(t, pi, sess1, k) + require.Len(t, pi.handleCache, 2, "expected two handles in handle cache") + + pi.returnSession(sess1) + require.Len(t, pi.sessions, 2, "expected two open sessions (sess2 in-use, sess1 cached)") + require.Len(t, pi.sessPool, 1, "sessionPool should have one handle (sess1)") + verifyHandleCache(t, pi, sess1, k) + require.Len(t, pi.handleCache, 2, "expected two handles in handle cache") + + pi.returnSession(sess2) + require.Len(t, pi.sessions, 1, "expected one cached session (sess1)") + require.Len(t, pi.sessPool, 1, "sessionPool should have one handle (sess1)") + require.Len(t, pi.handleCache, 2, "expected two handles in handle cache") + + sess1, err = pi.getSession() + require.NoError(t, err) + require.Len(t, pi.sessions, 1, "expected one open session (sess1)") + require.Len(t, pi.sessPool, 0, "sessionPool should be empty") + require.Len(t, pi.handleCache, 2, "expected two handles in handle cache") + + pi.slot = ^uint(0) // break OpenSession + _, err = pi.getSession() + require.EqualError(t, err, "OpenSession failed: pkcs11: 0x3: CKR_SLOT_ID_INVALID") + require.Len(t, pi.sessions, 1, "expected one active session (sess1)") + require.Len(t, pi.sessPool, 0, "sessionPool should be empty") + require.Len(t, pi.handleCache, 2, "expected two handles in handle cache") + + // Return a busted session that should be cached + pi.returnSession(pkcs11.SessionHandle(^uint(0))) + require.Len(t, pi.sessions, 1, "expected one active session (sess1)") + require.Len(t, pi.sessPool, 1, "sessionPool should contain busted session") + require.Len(t, pi.handleCache, 2, "expected two handles in handle cache") + + // Return sess1 that should be discarded + pi.returnSession(sess1) + require.Len(t, pi.sessions, 0, "expected sess1 to be removed") + require.Len(t, pi.sessPool, 1, "sessionPool should contain busted session") + require.Empty(t, pi.handleCache, "expected handles to be purged on removal of last tracked session") + + // Try to get broken session from cache + _, err = pi.getSession() + require.EqualError(t, err, "OpenSession failed: pkcs11: 0x3: CKR_SLOT_ID_INVALID") + require.Empty(t, pi.sessions, "expected sessions to be empty") + require.Len(t, pi.sessPool, 0, "sessionPool should be empty") + }) +} + func TestPKCS11ECKeySignVerify(t *testing.T) { msg1 := []byte("This is my very authentic message") msg2 := []byte("This is my very unauthentic message") @@ -1372,7 +1510,6 @@ func TestPKCS11ECKeySignVerify(t *testing.T) { } R, S, err := currentBCCSP.(*impl).signP11ECDSA(key, hash1) - if err != nil { t.Fatalf("Failed signing message [%s]", err) }