diff --git a/builtin/credential/approle/path_login.go b/builtin/credential/approle/path_login.go index cdfb5d47234c..2257f5f4af0e 100644 --- a/builtin/credential/approle/path_login.go +++ b/builtin/credential/approle/path_login.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "strings" + "time" "github.com/hashicorp/vault/logical" "github.com/hashicorp/vault/logical/framework" @@ -100,12 +101,17 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, data return nil, fmt.Errorf("role %s does not exist during renewal", roleName) } - resp, err := framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(ctx, req, data) - if err != nil { - return nil, err + // If a period is provided, set that as part of resp.Auth.Period and return a + // response immediately. Let expiration manager handle renewal from there on. + if role.Period > time.Duration(0) { + resp := &logical.Response{ + Auth: req.Auth, + } + resp.Auth.Period = role.Period + return resp, nil } - resp.Auth.Period = role.Period - return resp, nil + + return framework.LeaseExtend(role.TokenTTL, role.TokenMaxTTL, b.System())(ctx, req, data) } const pathLoginHelpSys = "Issue a token based on the credentials supplied" diff --git a/builtin/credential/approle/path_login_test.go b/builtin/credential/approle/path_login_test.go index 25af416c761c..44d2ec2c0c64 100644 --- a/builtin/credential/approle/path_login_test.go +++ b/builtin/credential/approle/path_login_test.go @@ -3,6 +3,7 @@ package approle import ( "context" "testing" + "time" "github.com/hashicorp/vault/logical" ) @@ -48,12 +49,106 @@ func TestAppRole_RoleLogin(t *testing.T) { RemoteAddr: "127.0.0.1", }, } - resp, err = b.HandleRequest(context.Background(), loginReq) + loginResp, err := b.HandleRequest(context.Background(), loginReq) + if err != nil || (loginResp != nil && loginResp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, loginResp) + } + + if loginResp.Auth == nil { + t.Fatalf("expected a non-nil auth object in the response") + } + + // Test renewal + renewReq := generateRenewRequest(storage, loginResp.Auth) + + resp, err = b.HandleRequest(context.Background(), renewReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + if resp.Auth.TTL != 400*time.Second { + t.Fatalf("expected period value from response to be 400s, got: %s", resp.Auth.TTL) + } + + /// + // Test renewal with period + /// + + // Create role + period := 600 * time.Second + roleData := map[string]interface{}{ + "policies": "a,b,c", + "period": period.String(), + } + roleReq := &logical.Request{ + Operation: logical.CreateOperation, + Path: "role/" + "role-period", + Storage: storage, + Data: roleData, + } + resp, err = b.HandleRequest(context.Background(), roleReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + roleRoleIDReq = &logical.Request{ + Operation: logical.ReadOperation, + Path: "role/role-period/role-id", + Storage: storage, + } + resp, err = b.HandleRequest(context.Background(), roleRoleIDReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + roleID = resp.Data["role_id"] + + roleSecretIDReq = &logical.Request{ + Operation: logical.UpdateOperation, + Path: "role/role-period/secret-id", + Storage: storage, + } + resp, err = b.HandleRequest(context.Background(), roleSecretIDReq) if err != nil || (resp != nil && resp.IsError()) { t.Fatalf("err:%v resp:%#v", err, resp) } + secretID = resp.Data["secret_id"] + + loginData["role_id"] = roleID + loginData["secret_id"] = secretID + + loginResp, err = b.HandleRequest(context.Background(), loginReq) + if err != nil || (loginResp != nil && loginResp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, loginResp) + } - if resp.Auth == nil { + if loginResp.Auth == nil { t.Fatalf("expected a non-nil auth object in the response") } + + renewReq = generateRenewRequest(storage, loginResp.Auth) + + resp, err = b.HandleRequest(context.Background(), renewReq) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("err:%v resp:%#v", err, resp) + } + + if resp.Auth.Period != period { + t.Fatalf("expected period value of %d in the response, got: %s", period, resp.Auth.Period) + } +} + +func generateRenewRequest(s logical.Storage, auth *logical.Auth) *logical.Request { + renewReq := &logical.Request{ + Operation: logical.RenewOperation, + Storage: s, + Auth: &logical.Auth{}, + } + renewReq.Auth.InternalData = auth.InternalData + renewReq.Auth.Metadata = auth.Metadata + renewReq.Auth.LeaseOptions = auth.LeaseOptions + renewReq.Auth.Policies = auth.Policies + renewReq.Auth.IssueTime = time.Now() + renewReq.Auth.Period = auth.Period + + return renewReq } diff --git a/builtin/credential/aws/backend_test.go b/builtin/credential/aws/backend_test.go index 3b38d1435195..1c6f4201853a 100644 --- a/builtin/credential/aws/backend_test.go +++ b/builtin/credential/aws/backend_test.go @@ -1615,6 +1615,40 @@ func TestBackendAcc_LoginWithCallerIdentity(t *testing.T) { if cachedArn == "" { t.Errorf("got empty ARN back from user ID cache; expected full arn") } + + // Test for renewal with period + period := 600 * time.Second + roleData["period"] = period.String() + roleRequest.Path = "role/" + testValidRoleName + resp, err = b.HandleRequest(context.Background(), roleRequest) + if err != nil || (resp != nil && resp.IsError()) { + t.Fatalf("bad: failed to create wildcard role: resp:%#v\nerr:%v", resp, err) + } + + loginData["role"] = testValidRoleName + resp, err = b.HandleRequest(context.Background(), loginRequest) + if err != nil { + t.Fatal(err) + } + if resp == nil || resp.Auth == nil || resp.IsError() { + t.Fatalf("bad: expected valid login: resp:%#v", resp) + } + + renewReq = generateRenewRequest(storage, resp.Auth) + resp, err = b.pathLoginRenew(context.Background(), renewReq, empty_login_fd) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("got nil response from renew") + } + if resp.IsError() { + t.Fatalf("got error when renewing: %#v", *resp) + } + + if resp.Auth.Period != period { + t.Fatalf("expected a period value of %s in the response, got: %s", period, resp.Auth.Period) + } } func generateRenewRequest(s logical.Storage, auth *logical.Auth) *logical.Request { @@ -1627,6 +1661,7 @@ func generateRenewRequest(s logical.Storage, auth *logical.Auth) *logical.Reques renewReq.Auth.LeaseOptions = auth.LeaseOptions renewReq.Auth.Policies = auth.Policies renewReq.Auth.IssueTime = time.Now() + renewReq.Auth.Period = auth.Period return renewReq } diff --git a/builtin/credential/aws/path_login.go b/builtin/credential/aws/path_login.go index a45de134364a..7a9d96739068 100644 --- a/builtin/credential/aws/path_login.go +++ b/builtin/credential/aws/path_login.go @@ -975,12 +975,17 @@ func (b *backend) pathLoginRenewIam(ctx context.Context, req *logical.Request, d } } - resp, err := framework.LeaseExtend(roleEntry.TTL, roleEntry.MaxTTL, b.System())(ctx, req, data) - if err != nil { - return nil, err + // If a period is provided, set that as part of resp.Auth.Period and return a + // response immediately. Let expiration manager handle renewal from there on. + if roleEntry.Period > time.Duration(0) { + resp := &logical.Response{ + Auth: req.Auth, + } + resp.Auth.Period = roleEntry.Period + return resp, nil } - resp.Auth.Period = roleEntry.Period - return resp, nil + + return framework.LeaseExtend(roleEntry.TTL, roleEntry.MaxTTL, b.System())(ctx, req, data) } func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { @@ -1060,12 +1065,17 @@ func (b *backend) pathLoginRenewEc2(ctx context.Context, req *logical.Request, d return nil, err } - resp, err := framework.LeaseExtend(roleEntry.TTL, shortestMaxTTL, b.System())(ctx, req, data) - if err != nil { - return nil, err + // If a period is provided, set that as part of resp.Auth.Period and return a + // response immediately. Let expiration manager handle renewal from there on. + if roleEntry.Period > time.Duration(0) { + resp := &logical.Response{ + Auth: req.Auth, + } + resp.Auth.Period = roleEntry.Period + return resp, nil } - resp.Auth.Period = roleEntry.Period - return resp, nil + + return framework.LeaseExtend(roleEntry.TTL, shortestMaxTTL, b.System())(ctx, req, data) } func (b *backend) pathLoginUpdateIam(ctx context.Context, req *logical.Request, data *framework.FieldData) (*logical.Response, error) { diff --git a/builtin/credential/cert/backend_test.go b/builtin/credential/cert/backend_test.go index 020e964010b0..9f5319376e6d 100644 --- a/builtin/credential/cert/backend_test.go +++ b/builtin/credential/cert/backend_test.go @@ -1195,6 +1195,7 @@ func Test_Renew(t *testing.T) { req.Auth.LeaseOptions = resp.Auth.LeaseOptions req.Auth.Policies = resp.Auth.Policies req.Auth.IssueTime = time.Now() + req.Auth.Period = resp.Auth.Period // Normal renewal resp, err = b.pathLoginRenew(context.Background(), req, empty_login_fd) @@ -1238,6 +1239,29 @@ func Test_Renew(t *testing.T) { t.Fatalf("got error: %#v", *resp) } + // Add period value to cert entry + period := 350 * time.Second + fd.Raw["period"] = period.String() + resp, err = b.pathCertWrite(context.Background(), req, fd) + if err != nil { + t.Fatal(err) + } + + resp, err = b.pathLoginRenew(context.Background(), req, empty_login_fd) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("got nil response from renew") + } + if resp.IsError() { + t.Fatalf("got error: %#v", *resp) + } + + if resp.Auth.Period != period { + t.Fatalf("expected a period value of %s in the response, got: %s", period, resp.Auth.Period) + } + // Delete CA, make sure we can't renew resp, err = b.pathCertDelete(context.Background(), req, fd) if err != nil { diff --git a/builtin/credential/cert/path_login.go b/builtin/credential/cert/path_login.go index 293d43db865c..5955c340587e 100644 --- a/builtin/credential/cert/path_login.go +++ b/builtin/credential/cert/path_login.go @@ -175,12 +175,17 @@ func (b *backend) pathLoginRenew(ctx context.Context, req *logical.Request, d *f return nil, fmt.Errorf("policies have changed, not renewing") } - resp, err := framework.LeaseExtend(cert.TTL, cert.MaxTTL, b.System())(ctx, req, d) - if err != nil { - return nil, err + // If a period is provided, set that as part of resp.Auth.Period and return a + // response immediately. Let expiration manager handle renewal from there on. + if cert.Period > time.Duration(0) { + resp := &logical.Response{ + Auth: req.Auth, + } + resp.Auth.Period = cert.Period + return resp, nil } - resp.Auth.Period = cert.Period - return resp, nil + + return framework.LeaseExtend(cert.TTL, cert.MaxTTL, b.System())(ctx, req, d) } func (b *backend) verifyCredentials(req *logical.Request, d *framework.FieldData) (*ParsedCert, *logical.Response, error) { diff --git a/vault/expiration_integ_test.go b/vault/expiration_integ_test.go new file mode 100644 index 000000000000..8763a1471b5c --- /dev/null +++ b/vault/expiration_integ_test.go @@ -0,0 +1,168 @@ +package vault_test + +import ( + "encoding/json" + "testing" + "time" + + "github.com/hashicorp/vault/api" + "github.com/hashicorp/vault/builtin/credential/approle" + vaulthttp "github.com/hashicorp/vault/http" + "github.com/hashicorp/vault/logical" + "github.com/hashicorp/vault/vault" +) + +func TestExpiration_RenewToken_TestCluster(t *testing.T) { + // Use a TestCluster and the approle backend to test renewal + coreConfig := &vault.CoreConfig{ + CredentialBackends: map[string]logical.Factory{ + "approle": approle.Factory, + }, + } + cluster := vault.NewTestCluster(t, coreConfig, &vault.TestClusterOptions{ + HandlerFunc: vaulthttp.Handler, + }) + cluster.Start() + defer cluster.Cleanup() + + core := cluster.Cores[0].Core + vault.TestWaitActive(t, core) + client := cluster.Cores[0].Client + + // Mount the auth backend + err := client.Sys().EnableAuthWithOptions("approle", &api.EnableAuthOptions{ + Type: "approle", + }) + if err != nil { + t.Fatal(err) + } + + // Tune the mount + err = client.Sys().TuneMount("auth/approle", api.MountConfigInput{ + DefaultLeaseTTL: "5s", + MaxLeaseTTL: "5s", + }) + if err != nil { + t.Fatal(err) + } + + // Create role + resp, err := client.Logical().Write("auth/approle/role/role-period", map[string]interface{}{ + "period": "5s", + }) + if err != nil { + t.Fatal(err) + } + + // Get role_id + resp, err = client.Logical().Read("auth/approle/role/role-period/role-id") + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected a response for fetching the role-id") + } + roleID := resp.Data["role_id"] + + // Get secret_id + resp, err = client.Logical().Write("auth/approle/role/role-period/secret-id", map[string]interface{}{}) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected a response for fetching the secret-id") + } + secretID := resp.Data["secret_id"] + + // Login + resp, err = client.Logical().Write("auth/approle/login", map[string]interface{}{ + "role_id": roleID, + "secret_id": secretID, + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected a response for login") + } + if resp.Auth == nil { + t.Fatal("expected auth object from response") + } + if resp.Auth.ClientToken == "" { + t.Fatal("expected a client token") + } + + roleToken := resp.Auth.ClientToken + // Wait 3 seconds + time.Sleep(3 * time.Second) + + // Renew + resp, err = client.Logical().Write("auth/token/renew", map[string]interface{}{ + "token": roleToken, + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected a response for renew") + } + + // Perform token lookup and verify TTL + resp, err = client.Auth().Token().Lookup(roleToken) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected a response for token lookup") + } + + ttlRaw, ok := resp.Data["ttl"].(json.Number) + if !ok { + t.Fatal("no ttl value found in data object") + } + ttlInt, err := ttlRaw.Int64() + if err != nil { + t.Fatalf("unable to convert ttl to int: %s", err) + } + ttl := time.Duration(ttlInt) * time.Second + if ttl < 4*time.Second { + t.Fatal("expected ttl value to be around 5s") + } + + // Wait 3 seconds + time.Sleep(3 * time.Second) + + // Do a second renewal to ensure that period can be renewed past sys/mount max_ttl + resp, err = client.Logical().Write("auth/token/renew", map[string]interface{}{ + "token": roleToken, + }) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected a response for renew") + } + + // Perform token lookup and verify TTL + resp, err = client.Auth().Token().Lookup(roleToken) + if err != nil { + t.Fatal(err) + } + if resp == nil { + t.Fatal("expected a response for token lookup") + } + + ttlRaw, ok = resp.Data["ttl"].(json.Number) + if !ok { + t.Fatal("no ttl value found in data object") + } + ttlInt, err = ttlRaw.Int64() + if err != nil { + t.Fatalf("unable to convert ttl to int: %s", err) + } + ttl = time.Duration(ttlInt) * time.Second + if ttl < 4*time.Second { + t.Fatal("expected ttl value to be around 5s") + } + +} diff --git a/vault/expiration_test.go b/vault/expiration_test.go index 0f639688de66..19f55f7a2479 100644 --- a/vault/expiration_test.go +++ b/vault/expiration_test.go @@ -790,7 +790,119 @@ func TestExpiration_RenewToken(t *testing.T) { } if auth.ClientToken != out.Auth.ClientToken { - t.Fatalf("Bad: %#v", out) + t.Fatalf("bad: %#v", out) + } +} + +func TestExpiration_RenewToken_period(t *testing.T) { + exp := mockExpiration(t) + root, err := exp.tokenStore.rootToken() + if err != nil { + t.Fatalf("err: %v", err) + } + + // Register a token + auth := &logical.Auth{ + ClientToken: root.ID, + LeaseOptions: logical.LeaseOptions{ + TTL: time.Hour, + Renewable: true, + }, + Period: time.Minute, + } + err = exp.RegisterAuth("auth/token/login", auth) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Renew the token + out, err := exp.RenewToken(&logical.Request{}, "auth/token/login", root.ID, 0) + if err != nil { + t.Fatalf("err: %v", err) + } + + if auth.ClientToken != out.Auth.ClientToken { + t.Fatalf("bad: %#v", out) + } + + if out.Auth.TTL > time.Minute { + t.Fatalf("expected TTL to be less than 1 minute, got: %s", out.Auth.TTL) + } +} + +func TestExpiration_RenewToken_period_backend(t *testing.T) { + exp := mockExpiration(t) + root, err := exp.tokenStore.rootToken() + if err != nil { + t.Fatalf("err: %v", err) + } + + // Mount a noop backend + noop := &NoopBackend{ + Response: &logical.Response{ + Auth: &logical.Auth{ + LeaseOptions: logical.LeaseOptions{ + TTL: 10 * time.Second, + Renewable: true, + }, + Period: 5 * time.Second, + }, + }, + DefaultLeaseTTL: 5 * time.Second, + MaxLeaseTTL: 5 * time.Second, + } + + _, barrier, _ := mockBarrier(t) + view := NewBarrierView(barrier, credentialBarrierPrefix) + meUUID, err := uuid.GenerateUUID() + if err != nil { + t.Fatal(err) + } + err = exp.router.Mount(noop, "auth/foo/", &MountEntry{Path: "auth/foo/", Type: "noop", UUID: meUUID, Accessor: "noop-accessor"}, view) + if err != nil { + t.Fatal(err) + } + + // Register a token + auth := &logical.Auth{ + ClientToken: root.ID, + LeaseOptions: logical.LeaseOptions{ + TTL: 10 * time.Second, + Renewable: true, + IssueTime: time.Now(), + }, + Period: 5 * time.Second, + } + + err = exp.RegisterAuth("auth/foo/login", auth) + if err != nil { + t.Fatalf("err: %v", err) + } + + // Wait 3 seconds + time.Sleep(3 * time.Second) + resp, err := exp.RenewToken(&logical.Request{}, "auth/foo/login", root.ID, 0) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatal("expected a response") + } + if resp.Auth.TTL > 5*time.Second { + t.Fatalf("expected TTL to be less than or equal to period, got: %s", resp.Auth.TTL) + } + + // Wait another 3 seconds. If period works correctly, this should not fail + time.Sleep(3 * time.Second) + resp, err = exp.RenewToken(&logical.Request{}, "auth/foo/login", root.ID, 0) + if err != nil { + t.Fatalf("err: %v", err) + } + if resp == nil { + t.Fatal("expected a response") + } + if resp.Auth.TTL < 4*time.Second || resp.Auth.TTL > 5*time.Second { + t.Fatalf("expected TTL to be around period's value, got: %s", resp.Auth.TTL) } } @@ -1222,7 +1334,7 @@ func TestExpiration_renewAuthEntry(t *testing.T) { }, } _, barrier, _ := mockBarrier(t) - view := NewBarrierView(barrier, "auth/foo/") + view := NewBarrierView(barrier, "auth/") meUUID, err := uuid.GenerateUUID() if err != nil { t.Fatal(err) diff --git a/vault/router_test.go b/vault/router_test.go index a3940e092fa6..6eb70d17bf5e 100644 --- a/vault/router_test.go +++ b/vault/router_test.go @@ -19,12 +19,14 @@ import ( type NoopBackend struct { sync.Mutex - Root []string - Login []string - Paths []string - Requests []*logical.Request - Response *logical.Response - Invalidations []string + Root []string + Login []string + Paths []string + Requests []*logical.Request + Response *logical.Response + Invalidations []string + DefaultLeaseTTL time.Duration + MaxLeaseTTL time.Duration } func (n *NoopBackend) HandleRequest(ctx context.Context, req *logical.Request) (*logical.Response, error) { @@ -53,9 +55,19 @@ func (n *NoopBackend) SpecialPaths() *logical.Paths { } func (n *NoopBackend) System() logical.SystemView { + defaultLeaseTTLVal := time.Hour * 24 + maxLeaseTTLVal := time.Hour * 24 * 32 + if n.DefaultLeaseTTL > 0 { + defaultLeaseTTLVal = n.DefaultLeaseTTL + } + + if n.MaxLeaseTTL > 0 { + maxLeaseTTLVal = n.MaxLeaseTTL + } + return logical.StaticSystemView{ - DefaultLeaseTTLVal: time.Hour * 24, - MaxLeaseTTLVal: time.Hour * 24 * 32, + DefaultLeaseTTLVal: defaultLeaseTTLVal, + MaxLeaseTTLVal: maxLeaseTTLVal, } }