diff --git a/client/vaultclient/vaultclient.go b/client/vaultclient/vaultclient.go index b5a2c3ba802c..5e47358d3af2 100644 --- a/client/vaultclient/vaultclient.go +++ b/client/vaultclient/vaultclient.go @@ -194,28 +194,36 @@ func (c *vaultClient) isTracked(id string) bool { return ok } +// isRunning returns true if the client is running. +func (c *vaultClient) isRunning() bool { + c.lock.RLock() + defer c.lock.RUnlock() + return c.running +} + // Starts the renewal loop of vault client func (c *vaultClient) Start() { + c.lock.Lock() + defer c.lock.Unlock() + if !c.config.IsEnabled() || c.running { return } - c.lock.Lock() c.running = true - c.lock.Unlock() go c.run() } // Stops the renewal loop of vault client func (c *vaultClient) Stop() { + c.lock.Lock() + defer c.lock.Unlock() + if !c.config.IsEnabled() || !c.running { return } - c.lock.Lock() - defer c.lock.Unlock() - c.running = false close(c.stopCh) } @@ -235,7 +243,7 @@ func (c *vaultClient) DeriveToken(alloc *structs.Allocation, taskNames []string) if !c.config.IsEnabled() { return nil, fmt.Errorf("vault client not enabled") } - if !c.running { + if !c.isRunning() { return nil, fmt.Errorf("vault client is not running") } @@ -505,7 +513,7 @@ func (c *vaultClient) run() { } var renewalCh <-chan time.Time - for c.config.IsEnabled() && c.running { + for c.config.IsEnabled() && c.isRunning() { // Fetches the candidate for next renewal renewalReq, renewalTime := c.nextRenewal() if renewalTime.IsZero() { diff --git a/client/vaultclient/vaultclient_test.go b/client/vaultclient/vaultclient_test.go index 84068a311eae..c62400a61b39 100644 --- a/client/vaultclient/vaultclient_test.go +++ b/client/vaultclient/vaultclient_test.go @@ -77,8 +77,11 @@ func TestVaultClient_TokenRenewals(t *testing.T) { }(errCh) } - if c.heap.Length() != num { - t.Fatalf("bad: heap length: expected: %d, actual: %d", num, c.heap.Length()) + c.lock.Lock() + length := c.heap.Length() + c.lock.Unlock() + if length != num { + t.Fatalf("bad: heap length: expected: %d, actual: %d", num, length) } time.Sleep(time.Duration(testutil.TestMultiplier()) * time.Second) @@ -89,8 +92,11 @@ func TestVaultClient_TokenRenewals(t *testing.T) { } } - if c.heap.Length() != 0 { - t.Fatalf("bad: heap length: expected: 0, actual: %d", c.heap.Length()) + c.lock.Lock() + length = c.heap.Length() + c.lock.Unlock() + if length != 0 { + t.Fatalf("bad: heap length: expected: 0, actual: %d", length) } }