diff --git a/builtin/logical/database/backend_test.go b/builtin/logical/database/backend_test.go index a613b03ac276..39c5a49a32da 100644 --- a/builtin/logical/database/backend_test.go +++ b/builtin/logical/database/backend_test.go @@ -12,10 +12,12 @@ import ( "testing" "time" + "github.com/go-test/deep" "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/helper/pluginutil" vaulthttp "github.com/hashicorp/vault/http" "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/logical/framework" "github.com/hashicorp/vault/plugins/database/postgresql" "github.com/hashicorp/vault/plugins/helper/database/dbutil" "github.com/hashicorp/vault/vault" @@ -185,53 +187,170 @@ func TestBackend_config_connection(t *testing.T) { config := logical.TestBackendConfig() config.StorageView = &logical.InmemStorage{} config.System = sys - b, err := Factory(context.Background(), config) + lb, err := Factory(context.Background(), config) if err != nil { t.Fatal(err) } + b, ok := lb.(*databaseBackend) + if !ok { + t.Fatal("could not convert to database backend") + } defer b.Cleanup(context.Background()) - configData := map[string]interface{}{ - "connection_url": "sample_connection_url", - "plugin_name": "postgresql-database-plugin", - "verify_connection": false, - "allowed_roles": []string{"*"}, - } + // Test creation + { + configData := map[string]interface{}{ + "connection_url": "sample_connection_url", + "someotherdata": "testing", + "plugin_name": "postgresql-database-plugin", + "verify_connection": false, + "allowed_roles": []string{"*"}, + "name": "plugin-test", + } - configReq := &logical.Request{ - Operation: logical.UpdateOperation, - Path: "config/plugin-test", - Storage: config.StorageView, - Data: configData, - } - resp, err = b.HandleRequest(context.Background(), configReq) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } + configReq := &logical.Request{ + Operation: logical.CreateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: configData, + } - expected := map[string]interface{}{ - "plugin_name": "postgresql-database-plugin", - "connection_details": map[string]interface{}{ - "connection_url": "sample_connection_url", - }, - "allowed_roles": []string{"*"}, - "root_credentials_rotate_statements": []string{}, + exists, err := b.connectionExistenceCheck()(context.Background(), configReq, &framework.FieldData{ + Raw: configData, + Schema: pathConfigurePluginConnection(b).Fields, + }) + if err != nil { + t.Fatal(err) + } + if exists { + t.Fatal("expected not exists") + } + + resp, err = b.HandleRequest(context.Background(), configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v\n", err, resp) + } + + expected := map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_details": map[string]interface{}{ + "connection_url": "sample_connection_url", + "someotherdata": "testing", + }, + "allowed_roles": []string{"*"}, + "root_credentials_rotate_statements": []string{}, + } + configReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(context.Background(), configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + delete(resp.Data["connection_details"].(map[string]interface{}), "name") + if !reflect.DeepEqual(expected, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) + } } - configReq.Operation = logical.ReadOperation - resp, err = b.HandleRequest(context.Background(), configReq) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, resp) + + // Test existence check and an update to a single connection detail parameter + { + configData := map[string]interface{}{ + "connection_url": "sample_convection_url", + "verify_connection": false, + "name": "plugin-test", + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: configData, + } + + exists, err := b.connectionExistenceCheck()(context.Background(), configReq, &framework.FieldData{ + Raw: configData, + Schema: pathConfigurePluginConnection(b).Fields, + }) + if err != nil { + t.Fatal(err) + } + if !exists { + t.Fatal("expected exists") + } + + resp, err = b.HandleRequest(context.Background(), configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v\n", err, resp) + } + + expected := map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_details": map[string]interface{}{ + "connection_url": "sample_convection_url", + "someotherdata": "testing", + }, + "allowed_roles": []string{"*"}, + "root_credentials_rotate_statements": []string{}, + } + configReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(context.Background(), configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + delete(resp.Data["connection_details"].(map[string]interface{}), "name") + if !reflect.DeepEqual(expected, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) + } } - delete(resp.Data["connection_details"].(map[string]interface{}), "name") - if !reflect.DeepEqual(expected, resp.Data) { - t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) + // Test an update to a non-details value + { + configData := map[string]interface{}{ + "verify_connection": false, + "allowed_roles": []string{"flu", "barre"}, + "name": "plugin-test", + } + + configReq := &logical.Request{ + Operation: logical.UpdateOperation, + Path: "config/plugin-test", + Storage: config.StorageView, + Data: configData, + } + + resp, err = b.HandleRequest(context.Background(), configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v\n", err, resp) + } + + expected := map[string]interface{}{ + "plugin_name": "postgresql-database-plugin", + "connection_details": map[string]interface{}{ + "connection_url": "sample_convection_url", + "someotherdata": "testing", + }, + "allowed_roles": []string{"flu", "barre"}, + "root_credentials_rotate_statements": []string{}, + } + configReq.Operation = logical.ReadOperation + resp, err = b.HandleRequest(context.Background(), configReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + delete(resp.Data["connection_details"].(map[string]interface{}), "name") + if !reflect.DeepEqual(expected, resp.Data) { + t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) + } } - configReq.Operation = logical.ListOperation - configReq.Data = nil - configReq.Path = "config/" - resp, err = b.HandleRequest(context.Background(), configReq) + req := &logical.Request{ + Operation: logical.ListOperation, + Storage: config.StorageView, + Path: "config/", + } + resp, err = b.HandleRequest(context.Background(), req) if err != nil { t.Fatal(err) } @@ -403,44 +522,98 @@ func TestBackend_basic(t *testing.T) { if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err:%s resp:%#v\n", err, resp) } - // Get creds - data = map[string]interface{}{} - req = &logical.Request{ - Operation: logical.ReadOperation, - Path: "creds/plugin-role-test", - Storage: config.StorageView, - Data: data, - } - credsResp, err = b.HandleRequest(context.Background(), req) - if err != nil || (credsResp != nil && credsResp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, credsResp) - } - // Test for #3812 - if credsResp.Secret.TTL != 5*time.Minute { - t.Fatalf("unexpected TTL of %d", credsResp.Secret.TTL) - } - if !testCredsExist(t, credsResp, connURL) { - t.Fatalf("Creds should exist") - } - // Revoke creds - resp, err = b.HandleRequest(context.Background(), &logical.Request{ - Operation: logical.RevokeOperation, - Storage: config.StorageView, - Secret: &logical.Secret{ - InternalData: map[string]interface{}{ - "secret_type": "creds", - "username": credsResp.Data["username"], - "role": "plugin-role-test", + // Get creds and revoke when the role stays in existence + { + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + credsResp, err = b.HandleRequest(context.Background(), req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + // Test for #3812 + if credsResp.Secret.TTL != 5*time.Minute { + t.Fatalf("unexpected TTL of %d", credsResp.Secret.TTL) + } + if !testCredsExist(t, credsResp, connURL) { + t.Fatalf("Creds should exist") + } + + // Revoke creds + resp, err = b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.RevokeOperation, + Storage: config.StorageView, + Secret: &logical.Secret{ + InternalData: map[string]interface{}{ + "secret_type": "creds", + "username": credsResp.Data["username"], + "role": "plugin-role-test", + }, }, - }, - }) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, resp) + }) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + if testCredsExist(t, credsResp, connURL) { + t.Fatalf("Creds should not exist") + } } - if testCredsExist(t, credsResp, connURL) { - t.Fatalf("Creds should not exist") + // Get creds and revoke using embedded revocation data + { + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "creds/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + credsResp, err = b.HandleRequest(context.Background(), req) + if err != nil || (credsResp != nil && credsResp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, credsResp) + } + if !testCredsExist(t, credsResp, connURL) { + t.Fatalf("Creds should exist") + } + + // Delete role, forcing us to rely on embedded data + req = &logical.Request{ + Operation: logical.DeleteOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + } + resp, err = b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + // Revoke creds + resp, err = b.HandleRequest(context.Background(), &logical.Request{ + Operation: logical.RevokeOperation, + Storage: config.StorageView, + Secret: &logical.Secret{ + InternalData: map[string]interface{}{ + "secret_type": "creds", + "username": credsResp.Data["username"], + "role": "plugin-role-test", + "db_name": "plugin-test", + "revocation_statements": []string(nil), + }, + }, + }) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + if testCredsExist(t, credsResp, connURL) { + t.Fatalf("Creds should not exist") + } } } @@ -545,7 +718,7 @@ func TestBackend_connectionCrud(t *testing.T) { "connection_url": connURL, }, "allowed_roles": []string{"plugin-role-test"}, - "root_credentials_rotate_statements": []string{}, + "root_credentials_rotate_statements": []string(nil), } req.Operation = logical.ReadOperation resp, err = b.HandleRequest(context.Background(), req) @@ -554,8 +727,8 @@ func TestBackend_connectionCrud(t *testing.T) { } delete(resp.Data["connection_details"].(map[string]interface{}), "name") - if !reflect.DeepEqual(expected, resp.Data) { - t.Fatalf("bad: expected:%#v\nactual:%#v\n", expected, resp.Data) + if diff := deep.Equal(resp.Data, expected); diff != nil { + t.Fatal(diff) } // Reset Connection @@ -626,10 +799,14 @@ func TestBackend_roleCrud(t *testing.T) { config.StorageView = &logical.InmemStorage{} config.System = sys - b, err := Factory(context.Background(), config) + lb, err := Factory(context.Background(), config) if err != nil { t.Fatal(err) } + b, ok := lb.(*databaseBackend) + if !ok { + t.Fatal("could not convert to db backend") + } defer b.Cleanup(context.Background()) cleanup, connURL := preparePostgresTestContainer(t, config.StorageView, b) @@ -651,52 +828,148 @@ func TestBackend_roleCrud(t *testing.T) { t.Fatalf("err:%s resp:%#v\n", err, resp) } - // Create a role - data = map[string]interface{}{ - "db_name": "plugin-test", - "creation_statements": testRole, - "revocation_statements": defaultRevocationSQL, - "default_ttl": "5m", - "max_ttl": "10m", - } - req = &logical.Request{ - Operation: logical.UpdateOperation, - Path: "roles/plugin-role-test", - Storage: config.StorageView, - Data: data, - } - resp, err = b.HandleRequest(context.Background(), req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } + // Test role creation + { + data = map[string]interface{}{ + "db_name": "plugin-test", + "creation_statements": testRole, + "revocation_statements": defaultRevocationSQL, + "default_ttl": "5m", + "max_ttl": "10m", + } + req = &logical.Request{ + Operation: logical.CreateOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } - // Read the role - data = map[string]interface{}{} - req = &logical.Request{ - Operation: logical.ReadOperation, - Path: "roles/plugin-role-test", - Storage: config.StorageView, - Data: data, - } - resp, err = b.HandleRequest(context.Background(), req) - if err != nil || (resp != nil && resp.IsError()) { - t.Fatalf("err:%s resp:%#v\n", err, resp) - } + exists, err := b.pathRoleExistenceCheck()(context.Background(), req, &framework.FieldData{ + Raw: data, + Schema: pathRoles(b).Fields, + }) + if err != nil { + t.Fatal(err) + } + if exists { + t.Fatal("expected not exists") + } - expected := dbplugin.Statements{ - Creation: []string{strings.TrimSpace(testRole)}, - Revocation: []string{strings.TrimSpace(defaultRevocationSQL)}, - } + // Read the role + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + expected := dbplugin.Statements{ + Creation: []string{strings.TrimSpace(testRole)}, + Revocation: []string{strings.TrimSpace(defaultRevocationSQL)}, + } + + actual := dbplugin.Statements{ + Creation: resp.Data["creation_statements"].([]string), + Revocation: resp.Data["revocation_statements"].([]string), + Rollback: resp.Data["rollback_statements"].([]string), + Renewal: resp.Data["renew_statements"].([]string), + } - actual := dbplugin.Statements{ - Creation: resp.Data["creation_statements"].([]string), - Revocation: resp.Data["revocation_statements"].([]string), - Rollback: resp.Data["rollback_statements"].([]string), - Renewal: resp.Data["renew_statements"].([]string), + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Statements did not match, expected %#v, got %#v", expected, actual) + } + + if diff := deep.Equal(resp.Data["db_name"], "plugin-test"); diff != nil { + t.Fatal(diff) + } + if diff := deep.Equal(resp.Data["default_ttl"], float64(300)); diff != nil { + t.Fatal(diff) + } + if diff := deep.Equal(resp.Data["max_ttl"], float64(600)); diff != nil { + t.Fatal(diff) + } } - if !reflect.DeepEqual(expected, actual) { - t.Fatalf("Statements did not match, expected %#v, got %#v", expected, actual) + // Test role modification + { + data = map[string]interface{}{ + "name": "plugin-role-test", + "rollback_statements": testRole, + "renew_statements": defaultRevocationSQL, + "max_ttl": "7m", + } + req = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v\n", err, resp) + } + + exists, err := b.pathRoleExistenceCheck()(context.Background(), req, &framework.FieldData{ + Raw: data, + Schema: pathRoles(b).Fields, + }) + if err != nil { + t.Fatal(err) + } + if !exists { + t.Fatal("expected exists") + } + + // Read the role + data = map[string]interface{}{} + req = &logical.Request{ + Operation: logical.ReadOperation, + Path: "roles/plugin-role-test", + Storage: config.StorageView, + Data: data, + } + resp, err = b.HandleRequest(context.Background(), req) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%s resp:%#v\n", err, resp) + } + + expected := dbplugin.Statements{ + Creation: []string{strings.TrimSpace(testRole)}, + Rollback: []string{strings.TrimSpace(testRole)}, + Revocation: []string{strings.TrimSpace(defaultRevocationSQL)}, + Renewal: []string{strings.TrimSpace(defaultRevocationSQL)}, + } + + actual := dbplugin.Statements{ + Creation: resp.Data["creation_statements"].([]string), + Revocation: resp.Data["revocation_statements"].([]string), + Rollback: resp.Data["rollback_statements"].([]string), + Renewal: resp.Data["renew_statements"].([]string), + } + + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Statements did not match, expected %#v, got %#v", expected, actual) + } + + if diff := deep.Equal(resp.Data["db_name"], "plugin-test"); diff != nil { + t.Fatal(diff) + } + if diff := deep.Equal(resp.Data["default_ttl"], float64(300)); diff != nil { + t.Fatal(diff) + } + if diff := deep.Equal(resp.Data["max_ttl"], float64(420)); diff != nil { + t.Fatal(diff) + } + } // Delete the role diff --git a/builtin/logical/database/path_config_connection.go b/builtin/logical/database/path_config_connection.go index 5274d9834668..bf3b378a5479 100644 --- a/builtin/logical/database/path_config_connection.go +++ b/builtin/logical/database/path_config_connection.go @@ -115,7 +115,9 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { }, }, + ExistenceCheck: b.connectionExistenceCheck(), Callbacks: map[logical.Operation]framework.OperationFunc{ + logical.CreateOperation: b.connectionWriteHandler(), logical.UpdateOperation: b.connectionWriteHandler(), logical.ReadOperation: b.connectionReadHandler(), logical.DeleteOperation: b.connectionDeleteHandler(), @@ -126,6 +128,22 @@ func pathConfigurePluginConnection(b *databaseBackend) *framework.Path { } } +func (b *databaseBackend) connectionExistenceCheck() framework.ExistenceFunc { + return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) { + name := data.Get("name").(string) + if name == "" { + return false, errors.New(`missing "name" parameter`) + } + + entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name)) + if err != nil { + return false, errors.New("failed to read connection configuration") + } + + return entry != nil, nil + } +} + func pathListPluginConnection(b *databaseBackend) *framework.Path { return &framework.Path{ Pattern: fmt.Sprintf("config/?$"), @@ -214,19 +232,46 @@ func (b *databaseBackend) connectionDeleteHandler() framework.OperationFunc { // both builtin and plugin database types. func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { - pluginName := data.Get("plugin_name").(string) - if pluginName == "" { - return logical.ErrorResponse(respErrEmptyPluginName), nil - } + verifyConnection := data.Get("verify_connection").(bool) name := data.Get("name").(string) if name == "" { return logical.ErrorResponse(respErrEmptyName), nil } - verifyConnection := data.Get("verify_connection").(bool) - allowedRoles := data.Get("allowed_roles").([]string) - rootRotationStatements := data.Get("root_rotation_statements").([]string) + // Baseline + config := &DatabaseConfig{} + + entry, err := req.Storage.Get(ctx, fmt.Sprintf("config/%s", name)) + if err != nil { + return nil, errors.New("failed to read connection configuration") + } + if entry != nil { + if err := entry.DecodeJSON(config); err != nil { + return nil, err + } + } + + if pluginNameRaw, ok := data.GetOk("plugin_name"); ok { + config.PluginName = pluginNameRaw.(string) + } else if req.Operation == logical.CreateOperation { + config.PluginName = data.Get("plugin_name").(string) + } + if config.PluginName == "" { + return logical.ErrorResponse(respErrEmptyPluginName), nil + } + + if allowedRolesRaw, ok := data.GetOk("allowed_roles"); ok { + config.AllowedRoles = allowedRolesRaw.([]string) + } else if req.Operation == logical.CreateOperation { + config.AllowedRoles = data.Get("allowed_roles").([]string) + } + + if rootRotationStatementsRaw, ok := data.GetOk("root_rotation_statements"); ok { + config.RootCredentialsRotateStatements = rootRotationStatementsRaw.([]string) + } else if req.Operation == logical.CreateOperation { + config.RootCredentialsRotateStatements = data.Get("root_rotation_statements").([]string) + } // Remove these entries from the data before we store it keyed under // ConnectionDetails. @@ -237,11 +282,26 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { delete(data.Raw, "root_rotation_statements") // Create a database plugin and initialize it. - db, err := dbplugin.PluginFactory(ctx, pluginName, b.System(), b.logger) + db, err := dbplugin.PluginFactory(ctx, config.PluginName, b.System(), b.logger) if err != nil { return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil } - connDetails, err := db.Init(ctx, data.Raw, verifyConnection) + + // If this is an update, take any new values, overwrite what was there + // before, and pass that in as the "new" set of values to the plugin, + // then save what results + if req.Operation == logical.CreateOperation { + config.ConnectionDetails = data.Raw + } else { + if config.ConnectionDetails == nil { + config.ConnectionDetails = make(map[string]interface{}) + } + for k, v := range data.Raw { + config.ConnectionDetails[k] = v + } + } + + config.ConnectionDetails, err = db.Init(ctx, config.ConnectionDetails, verifyConnection) if err != nil { db.Close() return logical.ErrorResponse(fmt.Sprintf("error creating database object: %s", err)), nil @@ -265,13 +325,7 @@ func (b *databaseBackend) connectionWriteHandler() framework.OperationFunc { } // Store it - config := &DatabaseConfig{ - ConnectionDetails: connDetails, - PluginName: pluginName, - AllowedRoles: allowedRoles, - RootCredentialsRotateStatements: rootRotationStatements, - } - entry, err := logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config) + entry, err = logical.StorageEntryJSON(fmt.Sprintf("config/%s", name), config) if err != nil { return nil, err } diff --git a/builtin/logical/database/path_creds_create.go b/builtin/logical/database/path_creds_create.go index 5b67393d5f5f..e002c6b7fa61 100644 --- a/builtin/logical/database/path_creds_create.go +++ b/builtin/logical/database/path_creds_create.go @@ -88,8 +88,10 @@ func (b *databaseBackend) pathCredsCreateRead() framework.OperationFunc { "username": username, "password": password, }, map[string]interface{}{ - "username": username, - "role": name, + "username": username, + "role": name, + "db_name": role.DBName, + "revocation_statements": role.Statements.Revocation, }) resp.Secret.TTL = role.DefaultTTL resp.Secret.MaxTTL = role.MaxTTL diff --git a/builtin/logical/database/path_roles.go b/builtin/logical/database/path_roles.go index 7fe30f2eb9c7..22a00df5201e 100644 --- a/builtin/logical/database/path_roles.go +++ b/builtin/logical/database/path_roles.go @@ -74,9 +74,11 @@ func pathRoles(b *databaseBackend) *framework.Path { }, }, + ExistenceCheck: b.pathRoleExistenceCheck(), Callbacks: map[logical.Operation]framework.OperationFunc{ logical.ReadOperation: b.pathRoleRead(), - logical.UpdateOperation: b.pathRoleCreate(), + logical.CreateOperation: b.pathRoleCreateUpdate(), + logical.UpdateOperation: b.pathRoleCreateUpdate(), logical.DeleteOperation: b.pathRoleDelete(), }, @@ -85,6 +87,17 @@ func pathRoles(b *databaseBackend) *framework.Path { } } +func (b *databaseBackend) pathRoleExistenceCheck() framework.ExistenceFunc { + return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (bool, error) { + role, err := b.Role(ctx, req.Storage, data.Get("name").(string)) + if err != nil { + return false, err + } + + return role != nil, nil + } +} + func (b *databaseBackend) pathRoleDelete() framework.OperationFunc { return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { err := req.Storage.Delete(ctx, "role/"+data.Get("name").(string)) @@ -131,44 +144,76 @@ func (b *databaseBackend) pathRoleList() framework.OperationFunc { } } -func (b *databaseBackend) pathRoleCreate() framework.OperationFunc { +func (b *databaseBackend) pathRoleCreateUpdate() framework.OperationFunc { return func(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { name := data.Get("name").(string) if name == "" { return logical.ErrorResponse("empty role name attribute given"), nil } - dbName := data.Get("db_name").(string) - if dbName == "" { - return logical.ErrorResponse("empty database name attribute given"), nil + role, err := b.Role(ctx, req.Storage, data.Get("name").(string)) + if err != nil { + return nil, err + } + if role == nil { + role = &roleEntry{} + } + + // DB Attributes + { + if dbNameRaw, ok := data.GetOk("db_name"); ok { + role.DBName = dbNameRaw.(string) + } else if req.Operation == logical.CreateOperation { + role.DBName = data.Get("db_name").(string) + } + if role.DBName == "" { + return logical.ErrorResponse("empty database name attribute"), nil + } + } + + // TTLs + { + if defaultTTLRaw, ok := data.GetOk("default_ttl"); ok { + role.DefaultTTL = time.Duration(defaultTTLRaw.(int)) * time.Second + } else if req.Operation == logical.CreateOperation { + role.DefaultTTL = time.Duration(data.Get("default_ttl").(int)) * time.Second + } + if maxTTLRaw, ok := data.GetOk("max_ttl"); ok { + role.MaxTTL = time.Duration(maxTTLRaw.(int)) * time.Second + } else if req.Operation == logical.CreateOperation { + role.MaxTTL = time.Duration(data.Get("max_ttl").(int)) * time.Second + } } - // Get statements - creationStmts := data.Get("creation_statements").([]string) - revocationStmts := data.Get("revocation_statements").([]string) - rollbackStmts := data.Get("rollback_statements").([]string) - renewStmts := data.Get("renew_statements").([]string) - - // Get TTLs - defaultTTLRaw := data.Get("default_ttl").(int) - maxTTLRaw := data.Get("max_ttl").(int) - defaultTTL := time.Duration(defaultTTLRaw) * time.Second - maxTTL := time.Duration(maxTTLRaw) * time.Second - - statements := dbplugin.Statements{ - Creation: creationStmts, - Revocation: revocationStmts, - Rollback: rollbackStmts, - Renewal: renewStmts, + // Statements + { + if creationStmtsRaw, ok := data.GetOk("creation_statements"); ok { + role.Statements.Creation = creationStmtsRaw.([]string) + } else if req.Operation == logical.CreateOperation { + role.Statements.Creation = data.Get("creation_statements").([]string) + } + + if revocationStmtsRaw, ok := data.GetOk("revocation_statements"); ok { + role.Statements.Revocation = revocationStmtsRaw.([]string) + } else if req.Operation == logical.CreateOperation { + role.Statements.Revocation = data.Get("revocation_statements").([]string) + } + + if rollbackStmtsRaw, ok := data.GetOk("rollback_statements"); ok { + role.Statements.Rollback = rollbackStmtsRaw.([]string) + } else if req.Operation == logical.CreateOperation { + role.Statements.Rollback = data.Get("rollback_statements").([]string) + } + + if renewStmtsRaw, ok := data.GetOk("renew_statements"); ok { + role.Statements.Renewal = renewStmtsRaw.([]string) + } else if req.Operation == logical.CreateOperation { + role.Statements.Renewal = data.Get("renew_statements").([]string) + } } // Store it - entry, err := logical.StorageEntryJSON("role/"+name, &roleEntry{ - DBName: dbName, - Statements: statements, - DefaultTTL: defaultTTL, - MaxTTL: maxTTL, - }) + entry, err := logical.StorageEntryJSON("role/"+name, role) if err != nil { return nil, err } diff --git a/builtin/logical/database/secret_creds.go b/builtin/logical/database/secret_creds.go index 4489b0798587..b1e3ddd87292 100644 --- a/builtin/logical/database/secret_creds.go +++ b/builtin/logical/database/secret_creds.go @@ -5,6 +5,7 @@ import ( "fmt" "time" + "github.com/hashicorp/vault/builtin/logical/database/dbplugin" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" ) @@ -91,16 +92,31 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { return nil, fmt.Errorf("no role name was provided") } + var dbName string + var statements dbplugin.Statements + role, err := b.Role(ctx, req.Storage, roleNameRaw.(string)) if err != nil { return nil, err } - if role == nil { - return nil, fmt.Errorf("error during revoke: could not find role with name %q", req.Secret.InternalData["role"]) + if role != nil { + dbName = role.DBName + statements = role.Statements + } else { + if dbNameRaw, ok := req.Secret.InternalData["db_name"]; !ok { + return nil, fmt.Errorf("error during revoke: could not find role with name %q or embedded revocation db name data", req.Secret.InternalData["role"]) + } else { + dbName = dbNameRaw.(string) + } + if statementsRaw, ok := req.Secret.InternalData["revocation_statements"]; !ok { + return nil, fmt.Errorf("error during revoke: could not find role with name %q or embedded revocation statement data", req.Secret.InternalData["role"]) + } else { + statements.Revocation = statementsRaw.([]string) + } } // Get our connection - db, err := b.GetConnection(ctx, req.Storage, role.DBName) + db, err := b.GetConnection(ctx, req.Storage, dbName) if err != nil { return nil, err } @@ -108,7 +124,7 @@ func (b *databaseBackend) secretCredsRevoke() framework.OperationFunc { db.RLock() defer db.RUnlock() - if err := db.RevokeUser(ctx, role.Statements, username); err != nil { + if err := db.RevokeUser(ctx, statements, username); err != nil { b.CloseIfShutdown(db, err) return nil, err } diff --git a/website/source/api/secret/databases/index.html.md b/website/source/api/secret/databases/index.html.md index ba1114029d77..0c0b2ea2960a 100644 --- a/website/source/api/secret/databases/index.html.md +++ b/website/source/api/secret/databases/index.html.md @@ -25,6 +25,8 @@ plugin has additional, database plugin specific, parameters for this endpoint. Please read the HTTP API for the plugin you'd wish to configure to see the full list of additional parameters. +~> This endpoint distinguishes between `create` and `update` ACL capabilities. + | Method | Path | Produces | | :------- | :--------------------------- | :--------------------- | | `POST` | `/database/config/:name` | `204 (empty body)` | @@ -208,6 +210,8 @@ $ curl \ This endpoint creates or updates a role definition. +~> This endpoint distinguishes between `create` and `update` ACL capabilities. + | Method | Path | Produces | | :------- | :--------------------------- | :--------------------- | | `POST` | `/database/roles/:name` | `204 (empty body)` |