Skip to content

Commit

Permalink
Add restart tracker to ecs-agent module
Browse files Browse the repository at this point in the history
Add restart tracker to ecs-agent module
  • Loading branch information
timj-hh committed May 3, 2024
2 parents 1b72d8f + 2d4969a commit 3b4c985
Show file tree
Hide file tree
Showing 2 changed files with 275 additions and 0 deletions.
99 changes: 99 additions & 0 deletions ecs-agent/api/container/restart/restart_tracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package restart

import (
"fmt"
"sync"
"time"

apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status"
)

type RestartTracker struct {
RestartCount int `json:"restartCount,omitempty"`
LastRestartAt time.Time `json:"lastRestartAt,omitempty"`
restartPolicy RestartPolicy
lock sync.RWMutex
}

// RestartPolicy represents a policy that contains key information considered when
// deciding whether or not a container should be restarted after it has exited.
type RestartPolicy struct {
Enabled bool `json:"enabled"`
IgnoredExitCodes []int `json:"ignoredExitCodes"`
RestartAttemptPeriod time.Duration `json:"restartAttemptPeriod"`
}

func NewRestartTracker(restartPolicy RestartPolicy) *RestartTracker {
return &RestartTracker{
restartPolicy: restartPolicy,
}
}

func (rt *RestartTracker) GetLastRestartAt() time.Time {
rt.lock.RLock()
defer rt.lock.RUnlock()
return rt.LastRestartAt
}

func (rt *RestartTracker) GetRestartCount() int {
rt.lock.RLock()
defer rt.lock.RUnlock()
return rt.RestartCount
}

// RecordRestart updates the restart tracker's metadata after a restart has occurred.
// This metadata is used to calculate when restarts should occur and track how many
// have occurred. It is not the job of this method to determine if a restart should
// occur or restart the container.
func (rt *RestartTracker) RecordRestart() {
rt.lock.Lock()
defer rt.lock.Unlock()
rt.RestartCount++
rt.LastRestartAt = time.Now()
}

// ShouldRestart returns whether the container should restart and a reason string
// explaining why not. The reset attempt period will be calculated first
// with LastRestart at, using the passed in startedAt if it does not exist.
func (rt *RestartTracker) ShouldRestart(exitCode *int, startedAt time.Time,
desiredStatus apicontainerstatus.ContainerStatus) (bool, string) {
rt.lock.RLock()
defer rt.lock.RUnlock()

if !rt.restartPolicy.Enabled {
return false, "restart policy is not enabled"
}
if desiredStatus == apicontainerstatus.ContainerStopped {
return false, "container's desired status is stopped"
}
if exitCode == nil {
return false, "exit code is nil"
}
for _, ignoredCode := range rt.restartPolicy.IgnoredExitCodes {
if ignoredCode == *exitCode {
return false, fmt.Sprintf("exit code %d should be ignored", *exitCode)
}
}

startTime := startedAt
if !rt.LastRestartAt.IsZero() {
startTime = rt.LastRestartAt
}
if time.Since(startTime) < rt.restartPolicy.RestartAttemptPeriod {
return false, "attempt reset period has not elapsed"
}
return true, ""
}
176 changes: 176 additions & 0 deletions ecs-agent/api/container/restart/restart_tracker_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
//go:build unit
// +build unit

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package restart

import (
"testing"
"time"

apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status"

"github.com/stretchr/testify/assert"
)

func TestShouldRestart(t *testing.T) {
ignoredCode := 0
rt := NewRestartTracker(RestartPolicy{
Enabled: false,
IgnoredExitCodes: []int{ignoredCode},
RestartAttemptPeriod: 60 * time.Second,
})
testCases := []struct {
name string
rp RestartPolicy
exitCode int
startedAt time.Time
desiredStatus apicontainerstatus.ContainerStatus
expected bool
expectedReason string
}{
{
name: "restart policy disabled",
rp: RestartPolicy{
Enabled: false,
IgnoredExitCodes: []int{ignoredCode},
RestartAttemptPeriod: 60 * time.Second,
},
exitCode: 1,
startedAt: time.Now().Add(2 * time.Minute),
desiredStatus: apicontainerstatus.ContainerRunning,
expected: false,
expectedReason: "restart policy is not enabled",
},
{
name: "ignored exit code",
rp: RestartPolicy{
Enabled: true,
IgnoredExitCodes: []int{ignoredCode},
RestartAttemptPeriod: 60 * time.Second,
},
exitCode: 0,
startedAt: time.Now().Add(-2 * time.Minute),
desiredStatus: apicontainerstatus.ContainerRunning,
expected: false,
expectedReason: "exit code 0 should be ignored",
},
{
name: "non ignored exit code",
rp: RestartPolicy{Enabled: true, IgnoredExitCodes: []int{ignoredCode}, RestartAttemptPeriod: 60 * time.Second},
exitCode: 1,
startedAt: time.Now().Add(-2 * time.Minute),
desiredStatus: apicontainerstatus.ContainerRunning,
expected: true,
expectedReason: "",
},
{
name: "nil exit code",
rp: RestartPolicy{Enabled: true, IgnoredExitCodes: []int{ignoredCode}, RestartAttemptPeriod: 60 * time.Second},
exitCode: -1,
startedAt: time.Now().Add(-2 * time.Minute),
desiredStatus: apicontainerstatus.ContainerRunning,
expected: false,
expectedReason: "exit code is nil",
},
{
name: "desired status stopped",
rp: RestartPolicy{Enabled: true, IgnoredExitCodes: []int{ignoredCode}, RestartAttemptPeriod: 60 * time.Second},
exitCode: 1,
startedAt: time.Now().Add(2 * time.Minute),
desiredStatus: apicontainerstatus.ContainerStopped,
expected: false,
expectedReason: "container's desired status is stopped",
},
{
name: "attempt reset period not elapsed",
rp: RestartPolicy{Enabled: true, IgnoredExitCodes: []int{ignoredCode}, RestartAttemptPeriod: 60 * time.Second},
exitCode: 1,
startedAt: time.Now(),
desiredStatus: apicontainerstatus.ContainerRunning,
expected: false,
expectedReason: "attempt reset period has not elapsed",
},
{
name: "attempt reset period not elapsed within one second",
rp: RestartPolicy{Enabled: true, IgnoredExitCodes: []int{ignoredCode}, RestartAttemptPeriod: 60 * time.Second},
exitCode: 1,
startedAt: time.Now().Add(-time.Second * 59),
desiredStatus: apicontainerstatus.ContainerRunning,
expected: false,
expectedReason: "attempt reset period has not elapsed",
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
rt.restartPolicy = tc.rp

// Because we cannot instantiate int pointers directly,
// check for the exit code and leave this int pointer as nil
// if there is no value to override it.
var exitCodeAdjusted *int
if tc.exitCode != -1 {
exitCodeAdjusted = &tc.exitCode
}

shouldRestart, reason := rt.ShouldRestart(exitCodeAdjusted, tc.startedAt, tc.desiredStatus)
assert.Equal(t, tc.expected, shouldRestart)
assert.Equal(t, tc.expectedReason, reason)
})
}
}

func TestShouldRestartUsesLastRestart(t *testing.T) {
rt := NewRestartTracker(RestartPolicy{
Enabled: true,
IgnoredExitCodes: []int{0},
RestartAttemptPeriod: 60 * time.Second,
})
exitCode := 1

shouldRestart, reason := rt.ShouldRestart(&exitCode, time.Now().Add(-61*time.Second), apicontainerstatus.ContainerRunning)
assert.True(t, shouldRestart)

// After restarting, we should inform restart decisions with LastRestartedAt instead of the passed in startedAt time.
rt.RecordRestart()
shouldRestart, reason = rt.ShouldRestart(&exitCode, time.Now().Add(-61*time.Second), apicontainerstatus.ContainerRunning)
assert.False(t, shouldRestart)
assert.Equal(t, "attempt reset period has not elapsed", reason)
}

func TestRecordRestart(t *testing.T) {
rt := NewRestartTracker(RestartPolicy{
Enabled: false,
RestartAttemptPeriod: 60 * time.Second,
})
assert.Equal(t, 0, rt.RestartCount)
for i := 1; i < 1000; i++ {
restartAt := time.Now()
rt.RecordRestart()
assert.Equal(t, i, rt.RestartCount)
assert.Equal(t, restartAt.Round(time.Second), rt.GetLastRestartAt().Round(time.Second))
}
}

func TestRecordRestartPolicy(t *testing.T) {
rt := NewRestartTracker(RestartPolicy{
Enabled: false,
RestartAttemptPeriod: 60 * time.Second,
})
assert.Equal(t, 0, rt.RestartCount)
assert.Equal(t, 0, len(rt.restartPolicy.IgnoredExitCodes))
assert.NotNil(t, rt.restartPolicy)
}

0 comments on commit 3b4c985

Please sign in to comment.