From d3f5bdcc04053dfb929b2cf88904329ed48970d5 Mon Sep 17 00:00:00 2001 From: Cameron Sparr Date: Fri, 7 Jul 2023 07:41:30 -0700 Subject: [PATCH] Support for setting PIDs limit for ECS tasks --- agent/api/task/task.go | 7 +- .../api/task/task_attachment_handler_test.go | 3 + agent/api/task/task_linux.go | 14 ++- agent/api/task/task_linux_test.go | 108 ++++++++++++++++-- agent/api/task/task_unsupported.go | 2 +- agent/api/task/task_windows.go | 2 +- agent/config/config.go | 1 + agent/config/parse_linux.go | 24 ++++ agent/config/parse_linux_test.go | 29 +++++ agent/config/parse_test.go | 33 ++++++ agent/config/parse_windows.go | 10 ++ agent/config/parse_windows_test.go | 23 ++++ agent/config/types.go | 6 + 13 files changed, 244 insertions(+), 18 deletions(-) create mode 100644 agent/config/parse_test.go diff --git a/agent/api/task/task.go b/agent/api/task/task.go index 10f2c9d5670..c623542e79a 100644 --- a/agent/api/task/task.go +++ b/agent/api/task/task.go @@ -363,9 +363,10 @@ func (task *Task) PostUnmarshalTask(cfg *config.Config, task.adjustForPlatform(cfg) - // TODO, add rudimentary plugin support and call any plugins that want to - // hook into this - if err := task.initializeCgroupResourceSpec(cfg.CgroupPath, cfg.CgroupCPUPeriod, resourceFields); err != nil { + // Initialize cgroup resource spec definition for later cgroup resource creation. + // This sets up the cgroup spec for cpu, memory, and pids limits for the task. + // Actual cgroup creation happens later. + if err := task.initializeCgroupResourceSpec(cfg.CgroupPath, cfg.CgroupCPUPeriod, cfg.TaskPidsLimit, resourceFields); err != nil { logger.Error("Could not initialize resource", logger.Fields{ field.TaskID: task.GetID(), field.Error: err, diff --git a/agent/api/task/task_attachment_handler_test.go b/agent/api/task/task_attachment_handler_test.go index 137a58f7c15..15e9e8c8bed 100644 --- a/agent/api/task/task_attachment_handler_test.go +++ b/agent/api/task/task_attachment_handler_test.go @@ -1,3 +1,6 @@ +//go:build unit +// +build unit + // Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"). You may diff --git a/agent/api/task/task_linux.go b/agent/api/task/task_linux.go index a2c38e63b53..04d5ac7a884 100644 --- a/agent/api/task/task_linux.go +++ b/agent/api/task/task_linux.go @@ -58,7 +58,7 @@ func (task *Task) adjustForPlatform(cfg *config.Config) { task.MemoryCPULimitsEnabled = cfg.TaskCPUMemLimit.Enabled() } -func (task *Task) initializeCgroupResourceSpec(cgroupPath string, cGroupCPUPeriod time.Duration, resourceFields *taskresource.ResourceFields) error { +func (task *Task) initializeCgroupResourceSpec(cgroupPath string, cGroupCPUPeriod time.Duration, taskPidsLimit int, resourceFields *taskresource.ResourceFields) error { if !task.MemoryCPULimitsEnabled { if task.CPU > 0 || task.Memory > 0 { // Client-side validation/warning if a task with task-level CPU/memory limits specified somehow lands on an instance @@ -74,7 +74,7 @@ func (task *Task) initializeCgroupResourceSpec(cgroupPath string, cGroupCPUPerio if err != nil { return errors.Wrapf(err, "cgroup resource: unable to determine cgroup root for task") } - resSpec, err := task.BuildLinuxResourceSpec(cGroupCPUPeriod) + resSpec, err := task.BuildLinuxResourceSpec(cGroupCPUPeriod, taskPidsLimit) if err != nil { return errors.Wrapf(err, "cgroup resource: unable to build resource spec for task") } @@ -122,7 +122,7 @@ func buildCgroupV2Root(taskID string) string { } // BuildLinuxResourceSpec returns a linuxResources object for the task cgroup -func (task *Task) BuildLinuxResourceSpec(cGroupCPUPeriod time.Duration) (specs.LinuxResources, error) { +func (task *Task) BuildLinuxResourceSpec(cGroupCPUPeriod time.Duration, taskPidsLimit int) (specs.LinuxResources, error) { linuxResourceSpec := specs.LinuxResources{} // If task level CPU limits are requested, set CPU quota + CPU period @@ -148,6 +148,14 @@ func (task *Task) BuildLinuxResourceSpec(cGroupCPUPeriod time.Duration) (specs.L linuxResourceSpec.Memory = &linuxMemorySpec } + // Set task pids limit if set via ECS_TASK_PIDS_LIMIT env var + if taskPidsLimit > 0 { + pidsLimit := &specs.LinuxPids{ + Limit: int64(taskPidsLimit), + } + linuxResourceSpec.Pids = pidsLimit + } + return linuxResourceSpec, nil } diff --git a/agent/api/task/task_linux_test.go b/agent/api/task/task_linux_test.go index 45691b97537..fdf8c9f2b02 100644 --- a/agent/api/task/task_linux_test.go +++ b/agent/api/task/task_linux_test.go @@ -303,7 +303,68 @@ func TestBuildLinuxResourceSpecCPUMem(t *testing.T) { }, } - linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod) + linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 0) + + assert.NoError(t, err) + assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec) +} + +// BuildLinuxResourceSpec tested with pid limits passed in. +func TestBuildLinuxResourceSpecCPUMem_WithPidLimits(t *testing.T) { + taskMemoryLimit := int64(taskMemoryLimit) + + task := &Task{ + Arn: validTaskArn, + CPU: float64(taskVCPULimit), + Memory: taskMemoryLimit, + } + + expectedTaskCPUPeriod := uint64(defaultCPUPeriod / time.Microsecond) + expectedTaskCPUQuota := int64(taskVCPULimit * float64(expectedTaskCPUPeriod)) + expectedTaskMemory := taskMemoryLimit * bytesPerMegabyte + expectedLinuxResourceSpec := specs.LinuxResources{ + CPU: &specs.LinuxCPU{ + Quota: &expectedTaskCPUQuota, + Period: &expectedTaskCPUPeriod, + }, + Memory: &specs.LinuxMemory{ + Limit: &expectedTaskMemory, + }, + Pids: &specs.LinuxPids{ + Limit: int64(100), + }, + } + + linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 100) + + assert.NoError(t, err) + assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec) +} + +// no pid limits expected when BuildLinuxResourceSpec receives an invalid value. +func TestBuildLinuxResourceSpecCPUMem_NegativeInvalidPidLimits(t *testing.T) { + taskMemoryLimit := int64(taskMemoryLimit) + + task := &Task{ + Arn: validTaskArn, + CPU: float64(taskVCPULimit), + Memory: taskMemoryLimit, + } + + expectedTaskCPUPeriod := uint64(defaultCPUPeriod / time.Microsecond) + expectedTaskCPUQuota := int64(taskVCPULimit * float64(expectedTaskCPUPeriod)) + expectedTaskMemory := taskMemoryLimit * bytesPerMegabyte + expectedLinuxResourceSpec := specs.LinuxResources{ + CPU: &specs.LinuxCPU{ + Quota: &expectedTaskCPUQuota, + Period: &expectedTaskCPUPeriod, + }, + Memory: &specs.LinuxMemory{ + Limit: &expectedTaskMemory, + }, + } + + linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, -1) assert.NoError(t, err) assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec) @@ -325,7 +386,7 @@ func TestBuildLinuxResourceSpecCPU(t *testing.T) { }, } - linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod) + linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 0) assert.NoError(t, err) assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec) @@ -340,7 +401,7 @@ func TestBuildLinuxResourceSpecIncreasedTaskCPULimit(t *testing.T) { CPU: increasedTaskVCPULimit, } - linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod) + linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 0) expectedTaskCPUPeriod := uint64(defaultCPUPeriod / time.Microsecond) expectedTaskCPUQuota := int64(increasedTaskVCPULimit * float64(expectedTaskCPUPeriod)) @@ -371,7 +432,34 @@ func TestBuildLinuxResourceSpecWithoutTaskCPULimits(t *testing.T) { }, } - linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod) + linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 0) + + assert.NoError(t, err) + assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec) +} + +// TestBuildLinuxResourceSpecWithoutTaskCPULimits validates behavior of CPU Shares +// validate that pid limits are also inserted correctly +func TestBuildLinuxResourceSpecWithoutTaskCPULimits_WithPidLimits(t *testing.T) { + task := &Task{ + Arn: validTaskArn, + Containers: []*apicontainer.Container{ + { + Name: "C1", + }, + }, + } + expectedCPUShares := uint64(minimumCPUShare) + expectedLinuxResourceSpec := specs.LinuxResources{ + CPU: &specs.LinuxCPU{ + Shares: &expectedCPUShares, + }, + Pids: &specs.LinuxPids{ + Limit: int64(100), + }, + } + + linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 100) assert.NoError(t, err) assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec) @@ -395,7 +483,7 @@ func TestBuildLinuxResourceSpecWithoutTaskCPUWithContainerCPULimits(t *testing.T }, } - linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod) + linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 0) assert.NoError(t, err) assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec) @@ -420,7 +508,7 @@ func TestBuildLinuxResourceSpecWithoutTaskCPUWithLessThanMinimumContainerCPULimi }, } - linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod) + linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 0) assert.NoError(t, err) assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec) @@ -443,7 +531,7 @@ func TestBuildLinuxResourceSpecInvalidMem(t *testing.T) { } expectedLinuxResourceSpec := specs.LinuxResources{} - linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod) + linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 0) assert.Error(t, err) assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec) @@ -593,7 +681,7 @@ func TestInitCgroupResourceSpecHappyPath(t *testing.T) { defer ctrl.Finish() mockControl := mock_control.NewMockControl(ctrl) mockIO := mock_ioutilwrapper.NewMockIOUtil(ctrl) - assert.NoError(t, task.initializeCgroupResourceSpec("cgroupPath", defaultCPUPeriod, &taskresource.ResourceFields{ + assert.NoError(t, task.initializeCgroupResourceSpec("cgroupPath", defaultCPUPeriod, 0, &taskresource.ResourceFields{ Control: mockControl, ResourceFieldsCommon: &taskresource.ResourceFieldsCommon{ IOUtil: mockIO, @@ -617,7 +705,7 @@ func TestInitCgroupResourceSpecInvalidARN(t *testing.T) { MemoryCPULimitsEnabled: true, ResourcesMapUnsafe: make(map[string][]taskresource.TaskResource), } - assert.Error(t, task.initializeCgroupResourceSpec("", time.Millisecond, nil)) + assert.Error(t, task.initializeCgroupResourceSpec("", time.Millisecond, 0, nil)) assert.Equal(t, 0, len(task.GetResources())) assert.Equal(t, 0, len(task.Containers[0].TransitionDependenciesMap)) } @@ -638,7 +726,7 @@ func TestInitCgroupResourceSpecInvalidMem(t *testing.T) { MemoryCPULimitsEnabled: true, ResourcesMapUnsafe: make(map[string][]taskresource.TaskResource), } - assert.Error(t, task.initializeCgroupResourceSpec("", time.Millisecond, nil)) + assert.Error(t, task.initializeCgroupResourceSpec("", time.Millisecond, 0, nil)) assert.Equal(t, 0, len(task.GetResources())) assert.Equal(t, 0, len(task.Containers[0].TransitionDependenciesMap)) } diff --git a/agent/api/task/task_unsupported.go b/agent/api/task/task_unsupported.go index 0f0f801ae68..c4407e21e4d 100644 --- a/agent/api/task/task_unsupported.go +++ b/agent/api/task/task_unsupported.go @@ -44,7 +44,7 @@ func (task *Task) adjustForPlatform(cfg *config.Config) { task.MemoryCPULimitsEnabled = cfg.TaskCPUMemLimit.Enabled() } -func (task *Task) initializeCgroupResourceSpec(cgroupPath string, cGroupCPUPeriod time.Duration, resourceFields *taskresource.ResourceFields) error { +func (task *Task) initializeCgroupResourceSpec(cgroupPath string, cGroupCPUPeriod time.Duration, taskPidsLimit int, resourceFields *taskresource.ResourceFields) error { if !task.MemoryCPULimitsEnabled { if task.CPU > 0 || task.Memory > 0 { // Client-side validation/warning if a task with task-level CPU/memory limits specified somehow lands on an instance diff --git a/agent/api/task/task_windows.go b/agent/api/task/task_windows.go index ea63016b4f9..fddd09de04d 100644 --- a/agent/api/task/task_windows.go +++ b/agent/api/task/task_windows.go @@ -133,7 +133,7 @@ func (task *Task) dockerCPUShares(containerCPU uint) int64 { return int64(containerCPU) } -func (task *Task) initializeCgroupResourceSpec(cgroupPath string, cGroupCPUPeriod time.Duration, resourceFields *taskresource.ResourceFields) error { +func (task *Task) initializeCgroupResourceSpec(cgroupPath string, cGroupCPUPeriod time.Duration, taskPidsLimit int, resourceFields *taskresource.ResourceFields) error { if !task.MemoryCPULimitsEnabled { if task.CPU > 0 || task.Memory > 0 { // Client-side validation/warning if a task with task-level CPU/memory limits specified somehow lands on an instance diff --git a/agent/config/config.go b/agent/config/config.go index 34339bdf4d2..be54c2ec1ef 100644 --- a/agent/config/config.go +++ b/agent/config/config.go @@ -596,6 +596,7 @@ func environmentConfig() (Config, error) { ShouldExcludeIPv6PortBinding: parseBooleanDefaultTrueConfig("ECS_EXCLUDE_IPV6_PORTBINDING"), WarmPoolsSupport: parseBooleanDefaultFalseConfig("ECS_WARM_POOLS_CHECK"), DynamicHostPortRange: parseDynamicHostPortRange("ECS_DYNAMIC_HOST_PORT_RANGE"), + TaskPidsLimit: parseTaskPidsLimit(), }, err } diff --git a/agent/config/parse_linux.go b/agent/config/parse_linux.go index 5ad45ea7ca1..927c3ef95b6 100644 --- a/agent/config/parse_linux.go +++ b/agent/config/parse_linux.go @@ -110,3 +110,27 @@ var IsWindows2016 = func() (bool, error) { func GetOSFamily() string { return strings.ToUpper(OSType) } + +func parseTaskPidsLimit() int { + var taskPidsLimit int + pidsLimitEnvVal := os.Getenv("ECS_TASK_PIDS_LIMIT") + if pidsLimitEnvVal == "" { + seelog.Debug("Environment variable empty: ECS_TASK_PIDS_LIMIT") + return 0 + } + + taskPidsLimit, err := strconv.Atoi(strings.TrimSpace(pidsLimitEnvVal)) + if err != nil { + seelog.Warnf(`Invalid format for "ECS_TASK_PIDS_LIMIT", expected an integer but got [%v]: %v`, pidsLimitEnvVal, err) + return 0 + } + + // 4194304 is a defacto limit set by runc on Amazon Linux (4*1024*1024), so + // we should use the same to avoid runtime container failures. + if taskPidsLimit <= 0 || taskPidsLimit > 4194304 { + seelog.Warnf(`Invalid value for "ECS_TASK_PIDS_LIMIT", expected integer greater than 0 and less than 4194305, but got [%v]`, taskPidsLimit) + return 0 + } + + return taskPidsLimit +} diff --git a/agent/config/parse_linux_test.go b/agent/config/parse_linux_test.go index 1c6715852ac..62031dc87ab 100644 --- a/agent/config/parse_linux_test.go +++ b/agent/config/parse_linux_test.go @@ -80,3 +80,32 @@ func TestSkipDomainLessCheckParseGMSACapability(t *testing.T) { assert.True(t, parseGMSADomainlessCapability().Enabled()) } + +func TestParseTaskPidsLimit(t *testing.T) { + t.Setenv("ECS_TASK_PIDS_LIMIT", "1") + assert.Equal(t, 1, parseTaskPidsLimit()) + t.Setenv("ECS_TASK_PIDS_LIMIT", "10") + assert.Equal(t, 10, parseTaskPidsLimit()) + t.Setenv("ECS_TASK_PIDS_LIMIT", "100") + assert.Equal(t, 100, parseTaskPidsLimit()) + t.Setenv("ECS_TASK_PIDS_LIMIT", "10000") + assert.Equal(t, 10000, parseTaskPidsLimit()) + // test the upper limit minus 1 + t.Setenv("ECS_TASK_PIDS_LIMIT", "4194304") + assert.Equal(t, 4194304, parseTaskPidsLimit()) + // test the upper limit + t.Setenv("ECS_TASK_PIDS_LIMIT", "4194305") + assert.Equal(t, 0, parseTaskPidsLimit()) + t.Setenv("ECS_TASK_PIDS_LIMIT", "0") + assert.Equal(t, 0, parseTaskPidsLimit()) + t.Setenv("ECS_TASK_PIDS_LIMIT", "-1") + assert.Equal(t, 0, parseTaskPidsLimit()) + t.Setenv("ECS_TASK_PIDS_LIMIT", "foobar") + assert.Equal(t, 0, parseTaskPidsLimit()) + t.Setenv("ECS_TASK_PIDS_LIMIT", "") + assert.Equal(t, 0, parseTaskPidsLimit()) +} + +func TestParseTaskPidsLimit_Unset(t *testing.T) { + assert.Equal(t, 0, parseTaskPidsLimit()) +} diff --git a/agent/config/parse_test.go b/agent/config/parse_test.go new file mode 100644 index 00000000000..7f8065eda86 --- /dev/null +++ b/agent/config/parse_test.go @@ -0,0 +1,33 @@ +package config + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseContainerInstanceTags(t *testing.T) { + // empty + t.Setenv("ECS_CONTAINER_INSTANCE_TAGS", "") + var expected, actual map[string]string + expectedErrs := []error{} + actualErrs := []error{} + actual, actualErrs = parseContainerInstanceTags(actualErrs) + assert.Equal(t, expected, actual) + assert.Equal(t, expectedErrs, actualErrs) + // with valid values + t.Setenv("ECS_CONTAINER_INSTANCE_TAGS", `{"foo":"bar","baz":"bin","num":"7"}`) + expected = map[string]string{"baz": "bin", "foo": "bar", "num": "7"} + expectedErrs = []error{} + actual, actualErrs = parseContainerInstanceTags(actualErrs) + assert.Equal(t, expected, actual) + assert.Equal(t, expectedErrs, actualErrs) + // with invalid values + t.Setenv("ECS_CONTAINER_INSTANCE_TAGS", `{"foo":"bar","baz":"bin,"num":"7"}`) // missing " + var expectedInvalid map[string]string + expectedErrs = []error{fmt.Errorf("Invalid format for ECS_CONTAINER_INSTANCE_TAGS. Expected a json hash: invalid character 'n' after object key:value pair")} + actual, actualErrs = parseContainerInstanceTags(actualErrs) + assert.Equal(t, expectedInvalid, actual) + assert.Equal(t, expectedErrs, actualErrs) +} diff --git a/agent/config/parse_windows.go b/agent/config/parse_windows.go index bddfd782a1f..2cb220fce04 100644 --- a/agent/config/parse_windows.go +++ b/agent/config/parse_windows.go @@ -182,3 +182,13 @@ func isDomainlessGmsaPluginInstalled() (bool, error) { return false, nil } + +func parseTaskPidsLimit() int { + pidsLimitEnvVal := os.Getenv("ECS_TASK_PIDS_LIMIT") + if pidsLimitEnvVal == "" { + seelog.Debug("Environment variable empty: ECS_TASK_PIDS_LIMIT") + return 0 + } + seelog.Warnf(`"ECS_TASK_PIDS_LIMIT" is not supported on windows`) + return 0 +} diff --git a/agent/config/parse_windows_test.go b/agent/config/parse_windows_test.go index a94d0bb8897..0f1a166614d 100644 --- a/agent/config/parse_windows_test.go +++ b/agent/config/parse_windows_test.go @@ -119,3 +119,26 @@ func TestParseDomainlessgMSACapabilityTrue(t *testing.T) { assert.True(t, parseGMSADomainlessCapability().Enabled()) } + +func TestParseTaskPidsLimit(t *testing.T) { + t.Setenv("ECS_TASK_PIDS_LIMIT", "1") + assert.Equal(t, 0, parseTaskPidsLimit()) + t.Setenv("ECS_TASK_PIDS_LIMIT", "10") + assert.Equal(t, 0, parseTaskPidsLimit()) + t.Setenv("ECS_TASK_PIDS_LIMIT", "100") + assert.Equal(t, 0, parseTaskPidsLimit()) + t.Setenv("ECS_TASK_PIDS_LIMIT", "10000") + assert.Equal(t, 0, parseTaskPidsLimit()) + t.Setenv("ECS_TASK_PIDS_LIMIT", "0") + assert.Equal(t, 0, parseTaskPidsLimit()) + t.Setenv("ECS_TASK_PIDS_LIMIT", "-1") + assert.Equal(t, 0, parseTaskPidsLimit()) + t.Setenv("ECS_TASK_PIDS_LIMIT", "foobar") + assert.Equal(t, 0, parseTaskPidsLimit()) + t.Setenv("ECS_TASK_PIDS_LIMIT", "") + assert.Equal(t, 0, parseTaskPidsLimit()) +} + +func TestParseTaskPidsLimit_Unset(t *testing.T) { + assert.Equal(t, 0, parseTaskPidsLimit()) +} diff --git a/agent/config/types.go b/agent/config/types.go index 620fecc92d3..ced906727c5 100644 --- a/agent/config/types.go +++ b/agent/config/types.go @@ -370,4 +370,10 @@ type Config struct { // uses to assign host ports from, for a container port range mapping. // This defaults to the platform specific ephemeral host port range DynamicHostPortRange string + + // TaskPidsLimit specifies the per-task pids limit cgroup setting for each + // task launched on this container instance. This setting maps to the pids.max + // cgroup setting at the ECS task level. + // see https://www.kernel.org/doc/html/latest/admin-guide/cgroup-v2.html#pid + TaskPidsLimit int }