Fix loading CSI driver container from state if it exists #3970

Merged: 1 commit merged on Oct 24, 2023
13 changes: 13 additions & 0 deletions agent/api/container/container.go
@@ -1508,3 +1508,16 @@ func (c *Container) GetContainerPortRangeMap() map[string]string {
defer c.lock.RUnlock()
return c.ContainerPortRangeMap
}

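// IsManagedDaemonContainer checks whether this container is a managed daemon container.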
func (c *Container) IsManagedDaemonContainer() bool {
c.lock.RLock()
defer c.lock.RUnlock()
return c.Type == ContainerManagedDaemon
}

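// GetImageName returns the container's image name with any tag stripped
// (the portion of Image before the first ":").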
func (c *Container) GetImageName() string {
c.lock.RLock()
defer c.lock.RUnlock()
containerImage := strings.Split(c.Image, ":")[0]
return containerImage
}
42 changes: 41 additions & 1 deletion agent/api/container/container_test.go
@@ -130,13 +130,53 @@ func TestIsInternal(t *testing.T) {
}

for _, tc := range testCases {
- t.Run(fmt.Sprintf("IsInternal shoukd return %t for %s", tc.internal, tc.container.String()),
+ t.Run(fmt.Sprintf("IsInternal should return %t for %s", tc.internal, tc.container.String()),
func(t *testing.T) {
assert.Equal(t, tc.internal, tc.container.IsInternal())
})
}
}

func TestIsManagedDaemonContainer(t *testing.T) {
testCases := []struct {
container *Container
internal bool
isManagedDaemon bool
}{
{&Container{}, false, false},
{&Container{Type: ContainerNormal, Image: "someImage:latest"}, false, false},
{&Container{Type: ContainerManagedDaemon, Image: "someImage:latest"}, true, true},
}

for _, tc := range testCases {
t.Run(fmt.Sprintf("IsManagedDaemonContainer should return %t for %s", tc.isManagedDaemon, tc.container.String()),
func(t *testing.T) {
assert.Equal(t, tc.internal, tc.container.IsInternal())
ok := tc.container.IsManagedDaemonContainer()
assert.Equal(t, tc.isManagedDaemon, ok)
})
}
}

func TestGetImageName(t *testing.T) {
testCases := []struct {
container *Container
imageName string
}{
{&Container{}, ""},
{&Container{Image: "someImage:latest"}, "someImage"},
{&Container{Image: "someImage"}, "someImage"},
}

for _, tc := range testCases {
t.Run(fmt.Sprintf("GetImageName should return %s for %s", tc.imageName, tc.container.String()),
func(t *testing.T) {
imageName := tc.container.GetImageName()
assert.Equal(t, tc.imageName, imageName)
})
}
}

// TestSetupExecutionRoleFlag tests whether or not the container appropriately
// sets the flag for using execution roles
func TestSetupExecutionRoleFlag(t *testing.T) {
1 change: 1 addition & 0 deletions agent/api/container/containertype.go
@@ -53,6 +53,7 @@ var stringToContainerType = map[string]ContainerType{
"EMPTY_HOST_VOLUME": ContainerEmptyHostVolume,
"CNI_PAUSE": ContainerCNIPause,
"NAMESPACE_PAUSE": ContainerNamespacePause,
"MANAGED_DAEMON": ContainerManagedDaemon,
}

// String converts the container type enum to a string
29 changes: 29 additions & 0 deletions agent/api/task/task.go
@@ -3687,3 +3687,32 @@ func (task *Task) HasActiveContainers() bool {
}
return false
}

// IsManagedDaemonTask will check if a task is a non-stopped managed daemon task
// TODO: Somehow track this on a task level (i.e. obtain the managed daemon image name from task arn and then find the corresponding container with the image name)
func (task *Task) IsManagedDaemonTask() (string, bool) {
task.lock.RLock()
defer task.lock.RUnlock()

// We'll want to obtain the last known non-stopped managed daemon task to be saved into our task engine.
// There can be an edge case where the task hasn't been progressed to RUNNING yet.
if !task.IsInternal || task.KnownStatusUnsafe.Terminal() {
return "", false
}

for _, c := range task.Containers {
if c.IsManagedDaemonContainer() {
imageName := c.GetImageName()
return imageName, true
Comment on lines +3705 to +3706

Member:

If a managed daemon in the future has more than a single container, we'll need to figure out which container is THE container. This might be an essential container, or else we'll need to enforce a one-daemon-container-per-daemon-task rule. We might want to add a TODO to track this as a task-level field in the future.
https://github.com/aws/amazon-ecs-agent/blob/master/agent/engine/daemonmanager/daemon_manager_linux.go#L131C3-L131C103
The daemons are currently very specifically focused on a single image as a base.

Contributor Author:

added
}
}
return "", false
}

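// IsRunning checks whether the task's known status is RUNNING.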
func (task *Task) IsRunning() bool {
task.lock.RLock()
defer task.lock.RUnlock()
taskStatus := task.KnownStatusUnsafe

return taskStatus == apitaskstatus.TaskRunning
}
102 changes: 102 additions & 0 deletions agent/api/task/task_test.go
@@ -5278,3 +5278,105 @@ func TestRemoveVolumeIndexOutOfBounds(t *testing.T) {
task.RemoveVolume(-1)
assert.Equal(t, len(task.Volumes), 1)
}

func TestIsManagedDaemonTask(t *testing.T) {

testTask1 := &Task{
Containers: []*apicontainer.Container{
{
Type: apicontainer.ContainerManagedDaemon,
Image: "someImage:latest",
},
},
IsInternal: true,
KnownStatusUnsafe: apitaskstatus.TaskRunning,
}

testTask2 := &Task{
Containers: []*apicontainer.Container{
{
Type: apicontainer.ContainerNormal,
Image: "someImage",
},
{
Type: apicontainer.ContainerNormal,
Image: "someImage:latest",
},
},
IsInternal: false,
KnownStatusUnsafe: apitaskstatus.TaskRunning,
}

testTask3 := &Task{
Containers: []*apicontainer.Container{
{
Type: apicontainer.ContainerManagedDaemon,
Image: "someImage:latest",
},
},
IsInternal: true,
KnownStatusUnsafe: apitaskstatus.TaskStopped,
}

testTask4 := &Task{
Containers: []*apicontainer.Container{
{
Type: apicontainer.ContainerManagedDaemon,
Image: "someImage:latest",
},
},
IsInternal: true,
KnownStatusUnsafe: apitaskstatus.TaskCreated,
}

testTask5 := &Task{
Containers: []*apicontainer.Container{
{
Type: apicontainer.ContainerNormal,
Image: "someImage",
},
},
IsInternal: true,
KnownStatusUnsafe: apitaskstatus.TaskStopped,
}

testCases := []struct {
task *Task
internal bool
isManagedDaemon bool
}{
{
task: testTask1,
internal: true,
isManagedDaemon: true,
},
{
task: testTask2,
internal: false,
isManagedDaemon: false,
},
{
task: testTask3,
internal: true,
isManagedDaemon: false,
},
{
task: testTask4,
internal: true,
isManagedDaemon: true,
},
{
task: testTask5,
internal: true,
isManagedDaemon: false,
},
}

for _, tc := range testCases {
t.Run(fmt.Sprintf("IsManagedDaemonTask should return %t for %s", tc.isManagedDaemon, tc.task.String()),
func(t *testing.T) {
_, ok := tc.task.IsManagedDaemonTask()
assert.Equal(t, tc.isManagedDaemon, ok)
})
}
}
5 changes: 3 additions & 2 deletions agent/ebs/watcher.go
@@ -126,7 +126,8 @@ func (w *EBSWatcher) HandleEBSResourceAttachment(ebs *apiebs.ResourceAttachment)
}

// start EBS CSI Driver Managed Daemon
- if runningCsiTask := w.taskEngine.GetDaemonTask(md.EbsCsiDriver); runningCsiTask != nil {
+ // We want to avoid creating a new CSI driver task if there's already one that's not been stopped.
+ if runningCsiTask := w.taskEngine.GetDaemonTask(md.EbsCsiDriver); runningCsiTask != nil && !runningCsiTask.GetKnownStatus().Terminal() {
log.Debugf("engine ebs CSI driver is running with taskID: %v", runningCsiTask.GetID())
} else {
if ebsCsiDaemonManager, ok := w.taskEngine.GetDaemonManagers()[md.EbsCsiDriver]; ok {
Expand Down Expand Up @@ -191,7 +192,7 @@ func (w *EBSWatcher) stageVolumeEBS(volID, deviceName string) error {
}
attachmentMountPath := ebsAttachment.GetAttachmentProperties(apiebs.SourceVolumeHostPathKey)
hostPath := filepath.Join(hostMountDir, attachmentMountPath)
- filesystemType := ebsAttachment.GetAttachmentProperties(apiebs.FileSystemTypeName)
+ filesystemType := ebsAttachment.GetAttachmentProperties(apiebs.FileSystemKey)
// CSI NodeStage stub required fields
stubSecrets := make(map[string]string)
stubVolumeContext := make(map[string]string)
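
For context on the fix itself: the guard added to HandleEBSResourceAttachment reuses a CSI driver task recovered from state only when that task has not reached a terminal status; otherwise a fresh daemon task is created. Below is a minimal, self-contained Go sketch of that decision. The TaskStatus, DaemonTask, and shouldStartNewDaemon names are simplified stand-ins for illustration, not the agent's actual types.

package main

import "fmt"

// TaskStatus is a simplified stand-in for the agent's apitaskstatus.TaskStatus.
type TaskStatus int

const (
	TaskCreated TaskStatus = iota
	TaskRunning
	TaskStopped
)

// Terminal reports whether the status is a final state, mirroring the
// !runningCsiTask.GetKnownStatus().Terminal() check in the diff above.
func (s TaskStatus) Terminal() bool { return s == TaskStopped }

// DaemonTask is a minimal stand-in for a daemon task loaded from saved state.
type DaemonTask struct {
	ID          string
	KnownStatus TaskStatus
}

// shouldStartNewDaemon reproduces the PR's guard: reuse the saved task only
// if it exists and is not terminal; otherwise start a new CSI driver task.
func shouldStartNewDaemon(saved *DaemonTask) bool {
	return saved == nil || saved.KnownStatus.Terminal()
}

func main() {
	fmt.Println(shouldStartNewDaemon(nil))                            // true: nothing in state
	fmt.Println(shouldStartNewDaemon(&DaemonTask{"t1", TaskRunning})) // false: reuse the running task
	fmt.Println(shouldStartNewDaemon(&DaemonTask{"t2", TaskStopped})) // true: saved task already stopped
}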