From 7b655b35fc32640c99120bb8b911bb1acfc60997 Mon Sep 17 00:00:00 2001 From: Chris Hoffman Date: Thu, 11 Jan 2018 11:15:21 -0500 Subject: [PATCH 1/3] updating locking --- builtin/logical/database/backend.go | 4 +- builtin/logical/database/path_creds_create.go | 18 +++++---- builtin/logical/database/secret_creds.go | 39 ++++++++++--------- 3 files changed, 33 insertions(+), 28 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index fdf4ec25b1a0..975e013cad08 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -198,13 +198,13 @@ func (b *databaseBackend) clearConnection(name string) { } } +// closeIfShutdown clears the connections if shutdown. The lock should be held +// entering this function. func (b *databaseBackend) closeIfShutdown(name string, err error) { // Plugin has shutdown, close it so next call can reconnect. switch err { case rpc.ErrShutdown, dbplugin.ErrPluginShutdown: - b.Lock() b.clearConnection(name) - b.Unlock() } } diff --git a/builtin/logical/database/path_creds_create.go b/builtin/logical/database/path_creds_create.go index ff1de51dabf3..6335f611c877 100644 --- a/builtin/logical/database/path_creds_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -56,7 +56,8 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { // Grab the read lock b.RLock() - var unlockFunc func() = b.RUnlock + unlockFunc := b.RUnlock + defer func() { unlockFunc() }() // Get the Database object db, ok := b.getDBObj(role.DBName) @@ -66,11 +67,14 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { b.Lock() unlockFunc = b.Unlock - // Create a new DB object - db, err = b.createDBObj(ctx, req.Storage, role.DBName) - if err != nil { - unlockFunc() - return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + // Check again + db, ok = b.getDBObj(role.DBName) + if !ok { + // Create a new DB object + db, err = b.createDBObj(ctx, req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + } } } @@ -83,8 +87,6 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { // Create the user username, password, err := db.CreateUser(ctx, role.Statements, usernameConfig, expiration) - // Unlock - unlockFunc() if err != nil { b.closeIfShutdown(role.DBName, err) return nil, err diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index f1b50142c79e..da76e4c4dd77 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -50,7 +50,8 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { // Grab the read lock b.RLock() - var unlockFunc func() = b.RUnlock + unlockFunc := b.RUnlock + defer func() { unlockFunc() }() // Get the Database object db, ok := b.getDBObj(role.DBName) @@ -60,19 +61,20 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { b.Lock() unlockFunc = b.Unlock - // Create a new DB object - db, err = b.createDBObj(ctx, req.Storage, role.DBName) - if err != nil { - unlockFunc() - return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + // Check again + db, ok = b.getDBObj(role.DBName) + if !ok { + // Create a new DB object + db, err = b.createDBObj(ctx, req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + } } } // Make sure we increase the VALID UNTIL endpoint for this user. if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { err := db.RenewUser(ctx, role.Statements, username, expireTime) - // Unlock - unlockFunc() if err != nil { b.closeIfShutdown(role.DBName, err) return nil, err @@ -109,7 +111,8 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { // Grab the read lock b.RLock() - var unlockFunc func() = b.RUnlock + unlockFunc := b.RUnlock + defer func() { unlockFunc() }() // Get our connection db, ok := b.getDBObj(role.DBName) @@ -119,18 +122,18 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { b.Lock() unlockFunc = b.Unlock - // Create a new DB object - db, err = b.createDBObj(ctx, req.Storage, role.DBName) - if err != nil { - unlockFunc() - return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + // Check again + db, ok = b.getDBObj(role.DBName) + if !ok { + // Create a new DB object + db, err = b.createDBObj(ctx, req.Storage, role.DBName) + if err != nil { + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) + } } } - err = db.RevokeUser(ctx, role.Statements, username) - // Unlock - unlockFunc() - if err != nil { + if err := db.RevokeUser(ctx, role.Statements, username); err != nil { b.closeIfShutdown(role.DBName, err) return nil, err } From 092bdcb21667f193c7eeb9c957823becefab4fed Mon Sep 17 00:00:00 2001 From: Chris Hoffman Date: Thu, 11 Jan 2018 11:31:21 -0500 Subject: [PATCH 2/3] ensure closeIfShutdown has the write lock --- builtin/logical/database/backend.go | 4 ++-- builtin/logical/database/path_creds_create.go | 5 ++++- builtin/logical/database/secret_creds.go | 8 ++++++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/builtin/logical/database/backend.go b/builtin/logical/database/backend.go index 975e013cad08..fdf4ec25b1a0 100644 --- a/builtin/logical/database/backend.go +++ b/builtin/logical/database/backend.go @@ -198,13 +198,13 @@ func (b *databaseBackend) clearConnection(name string) { } } -// closeIfShutdown clears the connections if shutdown. The lock should be held -// entering this function. func (b *databaseBackend) closeIfShutdown(name string, err error) { // Plugin has shutdown, close it so next call can reconnect. switch err { case rpc.ErrShutdown, dbplugin.ErrPluginShutdown: + b.Lock() b.clearConnection(name) + b.Unlock() } } diff --git a/builtin/logical/database/path_creds_create.go b/builtin/logical/database/path_creds_create.go index 6335f611c877..a271f4700cea 100644 --- a/builtin/logical/database/path_creds_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -57,7 +57,6 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { // Grab the read lock b.RLock() unlockFunc := b.RUnlock - defer func() { unlockFunc() }() // Get the Database object db, ok := b.getDBObj(role.DBName) @@ -73,6 +72,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { // Create a new DB object db, err = b.createDBObj(ctx, req.Storage, role.DBName) if err != nil { + unlockFunc() return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } } @@ -88,6 +88,7 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { // Create the user username, password, err := db.CreateUser(ctx, role.Statements, usernameConfig, expiration) if err != nil { + unlockFunc() b.closeIfShutdown(role.DBName, err) return nil, err } @@ -100,6 +101,8 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { "role": name, }) resp.Secret.TTL = role.DefaultTTL + + unlockFunc() return resp, nil } } diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index da76e4c4dd77..f6aabc48a44a 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -51,7 +51,6 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { // Grab the read lock b.RLock() unlockFunc := b.RUnlock - defer func() { unlockFunc() }() // Get the Database object db, ok := b.getDBObj(role.DBName) @@ -67,6 +66,7 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { // Create a new DB object db, err = b.createDBObj(ctx, req.Storage, role.DBName) if err != nil { + unlockFunc() return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } } @@ -76,11 +76,13 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { if expireTime := resp.Secret.ExpirationTime(); !expireTime.IsZero() { err := db.RenewUser(ctx, role.Statements, username, expireTime) if err != nil { + unlockFunc() b.closeIfShutdown(role.DBName, err) return nil, err } } + unlockFunc() return resp, nil } } @@ -112,7 +114,6 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { // Grab the read lock b.RLock() unlockFunc := b.RUnlock - defer func() { unlockFunc() }() // Get our connection db, ok := b.getDBObj(role.DBName) @@ -128,16 +129,19 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { // Create a new DB object db, err = b.createDBObj(ctx, req.Storage, role.DBName) if err != nil { + unlockFunc() return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } } } if err := db.RevokeUser(ctx, role.Statements, username); err != nil { + unlockFunc() b.closeIfShutdown(role.DBName, err) return nil, err } + unlockFunc() return resp, nil } } From 1f51628c4b3d5d3efb291746e30d99694bb9257a Mon Sep 17 00:00:00 2001 From: Chris Hoffman Date: Thu, 11 Jan 2018 12:45:56 -0500 Subject: [PATCH 3/3] removing object check since this is happening in create --- builtin/logical/database/path_creds_create.go | 14 ++++------ builtin/logical/database/secret_creds.go | 28 +++++++------------ 2 files changed, 15 insertions(+), 27 deletions(-) diff --git a/builtin/logical/database/path_creds_create.go b/builtin/logical/database/path_creds_create.go index a271f4700cea..16a56823ae47 100644 --- a/builtin/logical/database/path_creds_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -66,15 +66,11 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { b.Lock() unlockFunc = b.Unlock - // Check again - db, ok = b.getDBObj(role.DBName) - if !ok { - // Create a new DB object - db, err = b.createDBObj(ctx, req.Storage, role.DBName) - if err != nil { - unlockFunc() - return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) - } + // Create a new DB object + db, err = b.createDBObj(ctx, req.Storage, role.DBName) + if err != nil { + unlockFunc() + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } } diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index f6aabc48a44a..8f2a7b9a4a3e 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -60,15 +60,11 @@ func (b *databaseBackend) secretCredsRenew() framework.OperationFunc { b.Lock() unlockFunc = b.Unlock - // Check again - db, ok = b.getDBObj(role.DBName) - if !ok { - // Create a new DB object - db, err = b.createDBObj(ctx, req.Storage, role.DBName) - if err != nil { - unlockFunc() - return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) - } + // Create a new DB object + db, err = b.createDBObj(ctx, req.Storage, role.DBName) + if err != nil { + unlockFunc() + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } } @@ -123,15 +119,11 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { b.Lock() unlockFunc = b.Unlock - // Check again - db, ok = b.getDBObj(role.DBName) - if !ok { - // Create a new DB object - db, err = b.createDBObj(ctx, req.Storage, role.DBName) - if err != nil { - unlockFunc() - return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) - } + // Create a new DB object + db, err = b.createDBObj(ctx, req.Storage, role.DBName) + if err != nil { + unlockFunc() + return nil, fmt.Errorf("cound not retrieve db with name: %s, got error: %s", role.DBName, err) } }