Skip to content

Commit

Permalink
Support for setting PIDs limit for ECS tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
sparrc committed Jul 7, 2023
1 parent e9ce7c7 commit bb0e738
Show file tree
Hide file tree
Showing 10 changed files with 203 additions and 18 deletions.
7 changes: 4 additions & 3 deletions agent/api/task/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,9 +363,10 @@ func (task *Task) PostUnmarshalTask(cfg *config.Config,

task.adjustForPlatform(cfg)

// TODO, add rudimentary plugin support and call any plugins that want to
// hook into this
if err := task.initializeCgroupResourceSpec(cfg.CgroupPath, cfg.CgroupCPUPeriod, resourceFields); err != nil {
// Initialize cgroup resource spec definition for later cgroup resource creation.
// This sets up the cgroup spec for cpu, memory, and pids limits for the task.
// Actual cgroup creation happens later.
if err := task.initializeCgroupResourceSpec(cfg.CgroupPath, cfg.CgroupCPUPeriod, cfg.TaskPidsLimit, resourceFields); err != nil {
logger.Error("Could not initialize resource", logger.Fields{
field.TaskID: task.GetID(),
field.Error: err,
Expand Down
3 changes: 3 additions & 0 deletions agent/api/task/task_attachment_handler_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
//go:build unit
// +build unit

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
Expand Down
14 changes: 11 additions & 3 deletions agent/api/task/task_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (task *Task) adjustForPlatform(cfg *config.Config) {
task.MemoryCPULimitsEnabled = cfg.TaskCPUMemLimit.Enabled()
}

func (task *Task) initializeCgroupResourceSpec(cgroupPath string, cGroupCPUPeriod time.Duration, resourceFields *taskresource.ResourceFields) error {
func (task *Task) initializeCgroupResourceSpec(cgroupPath string, cGroupCPUPeriod time.Duration, taskPidsLimit int, resourceFields *taskresource.ResourceFields) error {
if !task.MemoryCPULimitsEnabled {
if task.CPU > 0 || task.Memory > 0 {
// Client-side validation/warning if a task with task-level CPU/memory limits specified somehow lands on an instance
Expand All @@ -74,7 +74,7 @@ func (task *Task) initializeCgroupResourceSpec(cgroupPath string, cGroupCPUPerio
if err != nil {
return errors.Wrapf(err, "cgroup resource: unable to determine cgroup root for task")
}
resSpec, err := task.BuildLinuxResourceSpec(cGroupCPUPeriod)
resSpec, err := task.BuildLinuxResourceSpec(cGroupCPUPeriod, taskPidsLimit)
if err != nil {
return errors.Wrapf(err, "cgroup resource: unable to build resource spec for task")
}
Expand Down Expand Up @@ -122,7 +122,7 @@ func buildCgroupV2Root(taskID string) string {
}

// BuildLinuxResourceSpec returns a linuxResources object for the task cgroup
func (task *Task) BuildLinuxResourceSpec(cGroupCPUPeriod time.Duration) (specs.LinuxResources, error) {
func (task *Task) BuildLinuxResourceSpec(cGroupCPUPeriod time.Duration, taskPidsLimit int) (specs.LinuxResources, error) {
linuxResourceSpec := specs.LinuxResources{}

// If task level CPU limits are requested, set CPU quota + CPU period
Expand All @@ -148,6 +148,14 @@ func (task *Task) BuildLinuxResourceSpec(cGroupCPUPeriod time.Duration) (specs.L
linuxResourceSpec.Memory = &linuxMemorySpec
}

// Set task pids limit if set via ECS_TASK_PIDS_LIMIT env var
if taskPidsLimit > 0 {
pidsLimit := &specs.LinuxPids{
Limit: int64(taskPidsLimit),
}
linuxResourceSpec.Pids = pidsLimit
}

return linuxResourceSpec, nil
}

Expand Down
108 changes: 98 additions & 10 deletions agent/api/task/task_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,68 @@ func TestBuildLinuxResourceSpecCPUMem(t *testing.T) {
},
}

linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod)
linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 0)

assert.NoError(t, err)
assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec)
}

// TestBuildLinuxResourceSpecCPUMem_WithPidLimits verifies that
// BuildLinuxResourceSpec emits CPU, memory, and pids settings when a positive
// pids limit is passed in alongside task-level CPU and memory limits.
func TestBuildLinuxResourceSpecCPUMem_WithPidLimits(t *testing.T) {
	const pidsLimit = 100
	memLimit := int64(taskMemoryLimit)

	// Expected values are derived the same way for CPU (quota from vCPUs and
	// period in microseconds) and memory (megabytes converted to bytes).
	cpuPeriod := uint64(defaultCPUPeriod / time.Microsecond)
	cpuQuota := int64(taskVCPULimit * float64(cpuPeriod))
	memBytes := memLimit * bytesPerMegabyte
	want := specs.LinuxResources{
		CPU:    &specs.LinuxCPU{Quota: &cpuQuota, Period: &cpuPeriod},
		Memory: &specs.LinuxMemory{Limit: &memBytes},
		Pids:   &specs.LinuxPids{Limit: int64(pidsLimit)},
	}

	task := &Task{
		Arn:    validTaskArn,
		CPU:    float64(taskVCPULimit),
		Memory: memLimit,
	}
	got, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, pidsLimit)

	assert.NoError(t, err)
	assert.EqualValues(t, want, got)
}

// TestBuildLinuxResourceSpecCPUMem_NegativeInvalidPidLimits verifies that no pid limits are set when BuildLinuxResourceSpec receives an invalid (negative) value.
func TestBuildLinuxResourceSpecCPUMem_NegativeInvalidPidLimits(t *testing.T) {
taskMemoryLimit := int64(taskMemoryLimit)

task := &Task{
Arn: validTaskArn,
CPU: float64(taskVCPULimit),
Memory: taskMemoryLimit,
}

expectedTaskCPUPeriod := uint64(defaultCPUPeriod / time.Microsecond)
expectedTaskCPUQuota := int64(taskVCPULimit * float64(expectedTaskCPUPeriod))
expectedTaskMemory := taskMemoryLimit * bytesPerMegabyte
expectedLinuxResourceSpec := specs.LinuxResources{
CPU: &specs.LinuxCPU{
Quota: &expectedTaskCPUQuota,
Period: &expectedTaskCPUPeriod,
},
Memory: &specs.LinuxMemory{
Limit: &expectedTaskMemory,
},
}

linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, -1)

assert.NoError(t, err)
assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec)
Expand All @@ -325,7 +386,7 @@ func TestBuildLinuxResourceSpecCPU(t *testing.T) {
},
}

linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod)
linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 0)

assert.NoError(t, err)
assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec)
Expand All @@ -340,7 +401,7 @@ func TestBuildLinuxResourceSpecIncreasedTaskCPULimit(t *testing.T) {
CPU: increasedTaskVCPULimit,
}

linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod)
linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 0)

expectedTaskCPUPeriod := uint64(defaultCPUPeriod / time.Microsecond)
expectedTaskCPUQuota := int64(increasedTaskVCPULimit * float64(expectedTaskCPUPeriod))
Expand Down Expand Up @@ -371,7 +432,34 @@ func TestBuildLinuxResourceSpecWithoutTaskCPULimits(t *testing.T) {
},
}

linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod)
linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 0)

assert.NoError(t, err)
assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec)
}

// TestBuildLinuxResourceSpecWithoutTaskCPULimits_WithPidLimits validates behavior of CPU Shares
// and validates that pid limits are also inserted correctly.
func TestBuildLinuxResourceSpecWithoutTaskCPULimits_WithPidLimits(t *testing.T) {
task := &Task{
Arn: validTaskArn,
Containers: []*apicontainer.Container{
{
Name: "C1",
},
},
}
expectedCPUShares := uint64(minimumCPUShare)
expectedLinuxResourceSpec := specs.LinuxResources{
CPU: &specs.LinuxCPU{
Shares: &expectedCPUShares,
},
Pids: &specs.LinuxPids{
Limit: int64(100),
},
}

linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 100)

assert.NoError(t, err)
assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec)
Expand All @@ -395,7 +483,7 @@ func TestBuildLinuxResourceSpecWithoutTaskCPUWithContainerCPULimits(t *testing.T
},
}

linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod)
linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 0)

assert.NoError(t, err)
assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec)
Expand All @@ -420,7 +508,7 @@ func TestBuildLinuxResourceSpecWithoutTaskCPUWithLessThanMinimumContainerCPULimi
},
}

linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod)
linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 0)

assert.NoError(t, err)
assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec)
Expand All @@ -443,7 +531,7 @@ func TestBuildLinuxResourceSpecInvalidMem(t *testing.T) {
}

expectedLinuxResourceSpec := specs.LinuxResources{}
linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod)
linuxResourceSpec, err := task.BuildLinuxResourceSpec(defaultCPUPeriod, 0)

assert.Error(t, err)
assert.EqualValues(t, expectedLinuxResourceSpec, linuxResourceSpec)
Expand Down Expand Up @@ -593,7 +681,7 @@ func TestInitCgroupResourceSpecHappyPath(t *testing.T) {
defer ctrl.Finish()
mockControl := mock_control.NewMockControl(ctrl)
mockIO := mock_ioutilwrapper.NewMockIOUtil(ctrl)
assert.NoError(t, task.initializeCgroupResourceSpec("cgroupPath", defaultCPUPeriod, &taskresource.ResourceFields{
assert.NoError(t, task.initializeCgroupResourceSpec("cgroupPath", defaultCPUPeriod, 0, &taskresource.ResourceFields{
Control: mockControl,
ResourceFieldsCommon: &taskresource.ResourceFieldsCommon{
IOUtil: mockIO,
Expand All @@ -617,7 +705,7 @@ func TestInitCgroupResourceSpecInvalidARN(t *testing.T) {
MemoryCPULimitsEnabled: true,
ResourcesMapUnsafe: make(map[string][]taskresource.TaskResource),
}
assert.Error(t, task.initializeCgroupResourceSpec("", time.Millisecond, nil))
assert.Error(t, task.initializeCgroupResourceSpec("", time.Millisecond, 0, nil))
assert.Equal(t, 0, len(task.GetResources()))
assert.Equal(t, 0, len(task.Containers[0].TransitionDependenciesMap))
}
Expand All @@ -638,7 +726,7 @@ func TestInitCgroupResourceSpecInvalidMem(t *testing.T) {
MemoryCPULimitsEnabled: true,
ResourcesMapUnsafe: make(map[string][]taskresource.TaskResource),
}
assert.Error(t, task.initializeCgroupResourceSpec("", time.Millisecond, nil))
assert.Error(t, task.initializeCgroupResourceSpec("", time.Millisecond, 0, nil))
assert.Equal(t, 0, len(task.GetResources()))
assert.Equal(t, 0, len(task.Containers[0].TransitionDependenciesMap))
}
Expand Down
2 changes: 1 addition & 1 deletion agent/api/task/task_unsupported.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func (task *Task) adjustForPlatform(cfg *config.Config) {
task.MemoryCPULimitsEnabled = cfg.TaskCPUMemLimit.Enabled()
}

func (task *Task) initializeCgroupResourceSpec(cgroupPath string, cGroupCPUPeriod time.Duration, resourceFields *taskresource.ResourceFields) error {
func (task *Task) initializeCgroupResourceSpec(cgroupPath string, cGroupCPUPeriod time.Duration, taskPidsLimit int, resourceFields *taskresource.ResourceFields) error {
if !task.MemoryCPULimitsEnabled {
if task.CPU > 0 || task.Memory > 0 {
// Client-side validation/warning if a task with task-level CPU/memory limits specified somehow lands on an instance
Expand Down
2 changes: 1 addition & 1 deletion agent/api/task/task_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (task *Task) dockerCPUShares(containerCPU uint) int64 {
return int64(containerCPU)
}

func (task *Task) initializeCgroupResourceSpec(cgroupPath string, cGroupCPUPeriod time.Duration, resourceFields *taskresource.ResourceFields) error {
func (task *Task) initializeCgroupResourceSpec(cgroupPath string, cGroupCPUPeriod time.Duration, taskPidsLimit int, resourceFields *taskresource.ResourceFields) error {
if !task.MemoryCPULimitsEnabled {
if task.CPU > 0 || task.Memory > 0 {
// Client-side validation/warning if a task with task-level CPU/memory limits specified somehow lands on an instance
Expand Down
1 change: 1 addition & 0 deletions agent/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,7 @@ func environmentConfig() (Config, error) {
ShouldExcludeIPv6PortBinding: parseBooleanDefaultTrueConfig("ECS_EXCLUDE_IPV6_PORTBINDING"),
WarmPoolsSupport: parseBooleanDefaultFalseConfig("ECS_WARM_POOLS_CHECK"),
DynamicHostPortRange: parseDynamicHostPortRange("ECS_DYNAMIC_HOST_PORT_RANGE"),
TaskPidsLimit: parseTaskPidsLimit(),
}, err
}

Expand Down
22 changes: 22 additions & 0 deletions agent/config/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,28 @@ func parseTaskMetadataThrottles() (int, int) {
return steadyStateRate, burstRate
}

// parseTaskPidsLimit reads the ECS_TASK_PIDS_LIMIT environment variable and
// returns the task-level pids limit it specifies.
//
// It returns 0, which callers treat as "no task pids limit", when the variable
// is unset or empty, is not a base-10 integer, or lies outside the range
// [1, 4194304]. Invalid values are logged as warnings rather than surfaced as
// errors, so a bad setting never blocks agent startup.
func parseTaskPidsLimit() int {
	// 4194304 (4*1024*1024) is the PID ceiling on 64-bit Linux. Rejecting
	// larger values here avoids deferring the failure to task cgroup creation,
	// where the kernel refuses to accept them for pids.max.
	// NOTE(review): confirm this ceiling matches every platform the agent targets.
	const maxTaskPidsLimit = 4 * 1024 * 1024

	pidsLimitEnvVal := os.Getenv("ECS_TASK_PIDS_LIMIT")
	if pidsLimitEnvVal == "" {
		seelog.Debug("Environment variable empty: ECS_TASK_PIDS_LIMIT")
		return 0
	}

	taskPidsLimit, err := strconv.Atoi(strings.TrimSpace(pidsLimitEnvVal))
	if err != nil {
		// %q makes whitespace-only or otherwise odd values visible in the log.
		seelog.Warnf(`Invalid format for "ECS_TASK_PIDS_LIMIT", expected an integer but got [%q]: %v`, pidsLimitEnvVal, err)
		return 0
	}

	if taskPidsLimit <= 0 || taskPidsLimit > maxTaskPidsLimit {
		seelog.Warnf(`Invalid value for "ECS_TASK_PIDS_LIMIT", expected integer greater than 0 and at most %d, but got [%v]`, maxTaskPidsLimit, taskPidsLimit)
		return 0
	}

	return taskPidsLimit
}

func parseContainerInstanceTags(errs []error) (map[string]string, []error) {
var containerInstanceTags map[string]string
containerInstanceTagsConfigString := os.Getenv("ECS_CONTAINER_INSTANCE_TAGS")
Expand Down
56 changes: 56 additions & 0 deletions agent/config/parse_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package config

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

// TestParseTaskPidsLimit drives parseTaskPidsLimit through a sequence of
// ECS_TASK_PIDS_LIMIT values: valid positive integers are returned as-is,
// while zero, negative, non-numeric, and empty values all yield 0.
func TestParseTaskPidsLimit(t *testing.T) {
	cases := []struct {
		envVal string
		want   int
	}{
		{envVal: "1", want: 1},
		{envVal: "10", want: 10},
		{envVal: "100", want: 100},
		{envVal: "10000", want: 10000},
		{envVal: "0", want: 0},
		{envVal: "-1", want: 0},
		{envVal: "foobar", want: 0},
		{envVal: "", want: 0},
	}

	for _, tc := range cases {
		t.Setenv("ECS_TASK_PIDS_LIMIT", tc.envVal)
		assert.Equal(t, tc.want, parseTaskPidsLimit())
	}
}

func TestParseTaskPidsLimit_Unset(t *testing.T) {
assert.Equal(t, 0, parseTaskPidsLimit())
}

// TestParseContainerInstanceTags exercises parseContainerInstanceTags with an
// empty value, a well-formed JSON object, and a malformed JSON object, reusing
// the same error slice across the three calls.
func TestParseContainerInstanceTags(t *testing.T) {
	const envKey = "ECS_CONTAINER_INSTANCE_TAGS"

	var (
		wantTags, gotTags map[string]string
		wantErrs          = []error{}
		gotErrs           = []error{}
	)

	// Empty value: a nil map and no errors are expected.
	t.Setenv(envKey, "")
	gotTags, gotErrs = parseContainerInstanceTags(gotErrs)
	assert.Equal(t, wantTags, gotTags)
	assert.Equal(t, wantErrs, gotErrs)

	// Well-formed JSON object: every key/value pair is returned, no errors.
	t.Setenv(envKey, `{"foo":"bar","baz":"bin","num":"7"}`)
	wantTags = map[string]string{"baz": "bin", "foo": "bar", "num": "7"}
	wantErrs = []error{}
	gotTags, gotErrs = parseContainerInstanceTags(gotErrs)
	assert.Equal(t, wantTags, gotTags)
	assert.Equal(t, wantErrs, gotErrs)

	// Malformed JSON (missing closing quote after "bin"): a nil map and a
	// single parse error are expected.
	t.Setenv(envKey, `{"foo":"bar","baz":"bin,"num":"7"}`)
	var wantInvalid map[string]string
	wantErrs = []error{fmt.Errorf("Invalid format for ECS_CONTAINER_INSTANCE_TAGS. Expected a json hash: invalid character 'n' after object key:value pair")}
	gotTags, gotErrs = parseContainerInstanceTags(gotErrs)
	assert.Equal(t, wantInvalid, gotTags)
	assert.Equal(t, wantErrs, gotErrs)
}
6 changes: 6 additions & 0 deletions agent/config/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -370,4 +370,10 @@ type Config struct {
// uses to assign host ports from, for a container port range mapping.
// This defaults to the platform specific ephemeral host port range
DynamicHostPortRange string

// TaskPidsLimit specifies the per-task pids limit cgroup setting for each
// task launched on this container instance. This setting maps to the pids.max
// cgroup setting at the ECS task level.
// see https://www.kernel.org/doc/html/latest/admin-guide/cgroup-v2.html#pid
TaskPidsLimit int
}

0 comments on commit bb0e738

Please sign in to comment.