diff --git a/agent/api/container/container.go b/agent/api/container/container.go index db6ce4ac659..3289caef3f5 100644 --- a/agent/api/container/container.go +++ b/agent/api/container/container.go @@ -318,7 +318,7 @@ type Container struct { labels map[string]string - // hasPortRange is set to true when the container has at least 1 port range requested. + // ContainerHasPortRange is set to true when the container has at least 1 port range requested. ContainerHasPortRange bool // ContainerPortSet is a set of singular container ports that don't belong to a containerPortRange request ContainerPortSet map[int]struct{} @@ -1374,7 +1374,7 @@ func (c *Container) SetContainerHasPortRange(containerHasPortRange bool) { c.ContainerHasPortRange = containerHasPortRange } -func (c *Container) GetContainerHasPortRange() bool { +func (c *Container) HasPortRange() bool { c.lock.RLock() defer c.lock.RUnlock() return c.ContainerHasPortRange diff --git a/agent/api/ecsclient/client.go b/agent/api/ecsclient/client.go index 10d32422c72..54b846810d8 100644 --- a/agent/api/ecsclient/client.go +++ b/agent/api/ecsclient/client.go @@ -20,8 +20,6 @@ import ( "strings" "time" - "github.com/aws/amazon-ecs-agent/agent/logger" - "github.com/aws/amazon-ecs-agent/agent/api" apicontainerstatus "github.com/aws/amazon-ecs-agent/agent/api/container/status" apierrors "github.com/aws/amazon-ecs-agent/agent/api/errors" @@ -30,12 +28,15 @@ import ( "github.com/aws/amazon-ecs-agent/agent/ec2" "github.com/aws/amazon-ecs-agent/agent/ecs_client/model/ecs" "github.com/aws/amazon-ecs-agent/agent/httpclient" + "github.com/aws/amazon-ecs-agent/agent/logger" "github.com/aws/amazon-ecs-agent/agent/utils" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/session" "github.com/cihub/seelog" "github.com/docker/docker/pkg/system" + "github.com/docker/go-connections/nat" ) const ( @@ -49,6 +50,11 @@ const ( osTypeAttrName = "ecs.os-type" osFamilyAttrName = "ecs.os-family" RoundtripTimeout = 5 * time.Second + // networkModeBridge specifies the bridge network mode. + networkModeBridge = "bridge" + // ecsMaxNetworkBindingsLength is the maximum length of the ecs.NetworkBindings list sent as part of the + // container state change payload. Currently, this is enforced only when containerPortRanges are requested. + ecsMaxNetworkBindingsLength = 100 ) // APIECSClient implements ECSClient @@ -419,7 +425,12 @@ func (client *APIECSClient) SubmitTaskStateChange(change api.TaskStateChange) er containerEvents := make([]*ecs.ContainerStateChange, len(change.Containers)) for i, containerEvent := range change.Containers { - containerEvents[i] = client.buildContainerStateChangePayload(containerEvent, client.config.ShouldExcludeIPv6PortBinding.Enabled()) + payload, err := client.buildContainerStateChangePayload(containerEvent, client.config.ShouldExcludeIPv6PortBinding.Enabled()) + if err != nil { + seelog.Errorf("Could not submit task state change: [%s]: %v", change.String(), err) + return err + } + containerEvents[i] = payload } req.Containers = containerEvents @@ -460,7 +471,7 @@ func (client *APIECSClient) buildManagedAgentStateChangePayload(change api.Manag } } -func (client *APIECSClient) buildContainerStateChangePayload(change api.ContainerStateChange, shouldExcludeIPv6PortBinding bool) *ecs.ContainerStateChange { +func (client *APIECSClient) buildContainerStateChangePayload(change api.ContainerStateChange, shouldExcludeIPv6PortBinding bool) (*ecs.ContainerStateChange, error) { statechange := &ecs.ContainerStateChange{ ContainerName: aws.String(change.ContainerName), } @@ -481,7 +492,7 @@ func (client *APIECSClient) buildContainerStateChangePayload(change api.Containe if status != apicontainerstatus.ContainerStopped && status != apicontainerstatus.ContainerRunning { seelog.Warnf("Not submitting unsupported upstream container state %s for container %s in task %s", status.String(), change.ContainerName, change.TaskArn) - return nil + return nil, nil } stat := change.Status.String() if stat == "DEAD" { @@ -494,7 +505,42 @@ func (client *APIECSClient) buildContainerStateChangePayload(change api.Containe statechange.ExitCode = aws.Int64(exitCode) } + networkBindings := getNetworkBindings(change, shouldExcludeIPv6PortBinding) + // we enforce a limit on the no. of network bindings for containers with at-least 1 port range requested. + // this limit is enforced by ECS, and we fail early and don't call SubmitContainerStateChange. + if change.Container.HasPortRange() && len(networkBindings) > ecsMaxNetworkBindingsLength { + return nil, fmt.Errorf("no. of network bindings %v is more than the maximum supported no. %v, "+ + "container: %s "+"task: %s", len(networkBindings), ecsMaxNetworkBindingsLength, change.ContainerName, change.TaskArn) + } + statechange.NetworkBindings = networkBindings + + return statechange, nil +} + +// ProtocolBindIP used to store protocol and bindIP information associated to a particular host port +type ProtocolBindIP struct { + protocol string + bindIP string +} + +// getNetworkBindings returns the list of networkingBindings, sent to ECS as part of the container state change payload +func getNetworkBindings(change api.ContainerStateChange, shouldExcludeIPv6PortBinding bool) []*ecs.NetworkBinding { networkBindings := []*ecs.NetworkBinding{} + // we return network bindings for bridge network mode tasks only + if change.Container.GetNetworkMode() != networkModeBridge { + return networkBindings + } + // hostPortToProtocolBindIPMap is a map to store protocol and bindIP information associated to host ports + // that belong to a range. This is used in case when there are multiple protocol/bindIP combinations associated to a + // port binding. example: when both IPv4 and IPv6 bindIPs are populated by docker and shouldExcludeIPv6PortBinding is false. + hostPortToProtocolBindIPMap := map[int64][]ProtocolBindIP{} + + // ContainerPortSet consists of singular ports, and ports that belong to a range, but for which we were not able to + // find contiguous host ports and ask docker to pick instead. + containerPortSet := change.Container.GetContainerPortSet() + // each entry in the ContainerPortRangeMap implies that we found a contiguous host port range for the same + containerPortRangeMap := change.Container.GetContainerPortRangeMap() + for _, binding := range change.PortBindings { if binding.BindIP == "::" && shouldExcludeIPv6PortBinding { seelog.Debugf("Exclude IPv6 port binding %v for container %s in task %s", binding, change.ContainerName, change.TaskArn) @@ -506,24 +552,54 @@ func (client *APIECSClient) buildContainerStateChangePayload(change api.Containe bindIP := binding.BindIP protocol := binding.Protocol.String() - networkBindings = append(networkBindings, &ecs.NetworkBinding{ - BindIP: aws.String(bindIP), - ContainerPort: aws.Int64(containerPort), - HostPort: aws.Int64(hostPort), - Protocol: aws.String(protocol), - }) + // create network binding for each containerPort that exists in the singular ContainerPortSet + // for container ports that belong to a range, we'll have 1 consolidated network binding for the range + if _, ok := containerPortSet[int(containerPort)]; ok { + networkBindings = append(networkBindings, &ecs.NetworkBinding{ + BindIP: aws.String(bindIP), + ContainerPort: aws.Int64(containerPort), + HostPort: aws.Int64(hostPort), + Protocol: aws.String(protocol), + }) + } else { + // populate hostPortToProtocolBindIPMap – this is used below when we construct network binding for ranges. + hostPortToProtocolBindIPMap[hostPort] = append(hostPortToProtocolBindIPMap[hostPort], + ProtocolBindIP{ + protocol: protocol, + bindIP: bindIP, + }) + } } - statechange.NetworkBindings = networkBindings - return statechange + for containerPortRange, hostPortRange := range containerPortRangeMap { + // we check for protocol and bindIP information associated to any one of the host ports from the hostPortRange, + // all ports belonging to the same range share this information. + hostPort, _, _ := nat.ParsePortRangeToInt(hostPortRange) + if val, ok := hostPortToProtocolBindIPMap[int64(hostPort)]; ok { + for _, v := range val { + networkBindings = append(networkBindings, &ecs.NetworkBinding{ + BindIP: aws.String(v.bindIP), + ContainerPortRange: aws.String(containerPortRange), + HostPortRange: aws.String(hostPortRange), + Protocol: aws.String(v.protocol), + }) + } + } + } + + return networkBindings } func (client *APIECSClient) SubmitContainerStateChange(change api.ContainerStateChange) error { - pl := client.buildContainerStateChangePayload(change, client.config.ShouldExcludeIPv6PortBinding.Enabled()) - if pl == nil { + pl, err := client.buildContainerStateChangePayload(change, client.config.ShouldExcludeIPv6PortBinding.Enabled()) + if err != nil { + seelog.Errorf("Could not build container state change payload: [%s]: %v", change.String(), err) + return err + } else if pl == nil { return nil } - _, err := client.submitStateChangeClient.SubmitContainerStateChange(&ecs.SubmitContainerStateChangeInput{ + + _, err = client.submitStateChangeClient.SubmitContainerStateChange(&ecs.SubmitContainerStateChangeInput{ Cluster: aws.String(client.config.Cluster), ContainerName: aws.String(change.ContainerName), ExitCode: pl.ExitCode, diff --git a/agent/api/ecsclient/client_test.go b/agent/api/ecsclient/client_test.go index 021dc16be29..47393810188 100644 --- a/agent/api/ecsclient/client_test.go +++ b/agent/api/ecsclient/client_test.go @@ -50,6 +50,7 @@ const ( iid = "instanceIdentityDocument" iidSignature = "signature" registrationToken = "clientToken" + testNetworkName = "bridge" ) var ( @@ -196,6 +197,12 @@ func TestSubmitContainerStateChange(t *testing.T) { HostPort: int64ptr(intptr(4)), Protocol: strptr("udp"), }, + { + BindIP: strptr("5.6.7.8"), + ContainerPortRange: strptr("11-12"), + HostPortRange: strptr("11-12"), + Protocol: strptr("udp"), + }, }, }, }) @@ -204,6 +211,18 @@ func TestSubmitContainerStateChange(t *testing.T) { ContainerName: "cont", RuntimeID: "runtime id", Status: apicontainerstatus.ContainerRunning, + Container: &apicontainer.Container{ + ContainerArn: "arn", + NetworkModeUnsafe: testNetworkName, + ContainerHasPortRange: true, + ContainerPortSet: map[int]struct{}{ + 1: {}, + 3: {}, + }, + ContainerPortRangeMap: map[string]string{ + "11-12": "11-12", + }, + }, PortBindings: []apicontainer.PortBinding{ { BindIP: "1.2.3.4", @@ -216,6 +235,18 @@ func TestSubmitContainerStateChange(t *testing.T) { HostPort: 4, Protocol: apicontainer.TransportProtocolUDP, }, + { + BindIP: "5.6.7.8", + ContainerPort: aws.Uint16(11), + HostPort: 11, + Protocol: apicontainer.TransportProtocolUDP, + }, + { + BindIP: "5.6.7.8", + ContainerPort: aws.Uint16(12), + HostPort: 12, + Protocol: apicontainer.TransportProtocolUDP, + }, }, }) if err != nil { @@ -232,21 +263,14 @@ func TestSubmitContainerStateChangeFull(t *testing.T) { mockSubmitStateClient.EXPECT().SubmitContainerStateChange(&containerSubmitInputMatcher{ ecs.SubmitContainerStateChangeInput{ - Cluster: strptr(configuredCluster), - Task: strptr("arn"), - ContainerName: strptr("cont"), - RuntimeId: strptr("runtime id"), - Status: strptr("STOPPED"), - ExitCode: int64ptr(&exitCode), - Reason: strptr(reason), - NetworkBindings: []*ecs.NetworkBinding{ - { - BindIP: strptr(""), - ContainerPort: int64ptr(intptr(0)), - HostPort: int64ptr(intptr(0)), - Protocol: strptr("tcp"), - }, - }, + Cluster: strptr(configuredCluster), + Task: strptr("arn"), + ContainerName: strptr("cont"), + RuntimeId: strptr("runtime id"), + Status: strptr("STOPPED"), + ExitCode: int64ptr(&exitCode), + Reason: strptr(reason), + NetworkBindings: []*ecs.NetworkBinding{}, }, }) err := client.SubmitContainerStateChange(api.ContainerStateChange{ @@ -256,6 +280,9 @@ func TestSubmitContainerStateChangeFull(t *testing.T) { Status: apicontainerstatus.ContainerStopped, ExitCode: &exitCode, Reason: reason, + Container: &apicontainer.Container{ + NetworkModeUnsafe: testNetworkName, + }, PortBindings: []apicontainer.PortBinding{ {}, }, @@ -286,9 +313,12 @@ func TestSubmitContainerStateChangeReason(t *testing.T) { err := client.SubmitContainerStateChange(api.ContainerStateChange{ TaskArn: "arn", ContainerName: "cont", - Status: apicontainerstatus.ContainerStopped, - ExitCode: &exitCode, - Reason: reason, + Container: &apicontainer.Container{ + NetworkModeUnsafe: testNetworkName, + }, + Status: apicontainerstatus.ContainerStopped, + ExitCode: &exitCode, + Reason: reason, }) if err != nil { t.Fatal(err) @@ -317,9 +347,12 @@ func TestSubmitContainerStateChangeLongReason(t *testing.T) { err := client.SubmitContainerStateChange(api.ContainerStateChange{ TaskArn: "arn", ContainerName: "cont", - Status: apicontainerstatus.ContainerStopped, - ExitCode: &exitCode, - Reason: reason, + Container: &apicontainer.Container{ + NetworkModeUnsafe: testNetworkName, + }, + Status: apicontainerstatus.ContainerStopped, + ExitCode: &exitCode, + Reason: reason, }) if err != nil { t.Errorf("Unable to submit container state change: %v", err) @@ -1114,7 +1147,10 @@ func TestSubmitContainerStateChangeWhileTaskInPending(t *testing.T) { TaskArn: "arn", ContainerName: "container", RuntimeID: "runtimeid", - Status: apicontainerstatus.ContainerRunning, + Container: &apicontainer.Container{ + NetworkModeUnsafe: testNetworkName, + }, + Status: apicontainerstatus.ContainerRunning, }, }, } @@ -1152,3 +1188,136 @@ func extractTagsMapFromRegisterContainerInstanceInput(req *ecs.RegisterContainer } return tagsMap } + +func getTestContainerStateChange() api.ContainerStateChange { + testContainer := &apicontainer.Container{ + Name: "cont", + NetworkModeUnsafe: testNetworkName, + Ports: []apicontainer.PortBinding{ + { + ContainerPort: aws.Uint16(10), + HostPort: 10, + Protocol: apicontainer.TransportProtocolTCP, + }, + { + ContainerPort: aws.Uint16(12), + HostPort: 12, + Protocol: apicontainer.TransportProtocolUDP, + }, + { + ContainerPort: aws.Uint16(15), + Protocol: apicontainer.TransportProtocolTCP, + }, + { + ContainerPortRange: aws.String("21-22"), + Protocol: apicontainer.TransportProtocolUDP, + }, + { + ContainerPortRange: aws.String("96-97"), + Protocol: apicontainer.TransportProtocolTCP, + }, + }, + ContainerHasPortRange: true, + ContainerPortSet: map[int]struct{}{ + 10: {}, + 12: {}, + 15: {}, + }, + ContainerPortRangeMap: map[string]string{ + "21-22": "60001-60002", + "96-97": "47001-47002", + }, + } + + testContainerStateChange := api.ContainerStateChange{ + TaskArn: "arn", + ContainerName: "cont", + Status: apicontainerstatus.ContainerRunning, + Container: testContainer, + PortBindings: []apicontainer.PortBinding{ + { + ContainerPort: aws.Uint16(10), + HostPort: 10, + BindIP: "0.0.0.0", + Protocol: apicontainer.TransportProtocolTCP, + }, + { + ContainerPort: aws.Uint16(12), + HostPort: 12, + BindIP: "1.2.3.4", + Protocol: apicontainer.TransportProtocolUDP, + }, + { + ContainerPort: aws.Uint16(15), + HostPort: 20, + BindIP: "5.6.7.8", + Protocol: apicontainer.TransportProtocolTCP, + }, + { + ContainerPort: aws.Uint16(21), + HostPort: 60001, + BindIP: "::", + Protocol: apicontainer.TransportProtocolUDP, + }, + { + ContainerPort: aws.Uint16(22), + HostPort: 60002, + BindIP: "::", + Protocol: apicontainer.TransportProtocolUDP, + }, + { + ContainerPort: aws.Uint16(96), + HostPort: 47001, + BindIP: "0.0.0.0", + Protocol: apicontainer.TransportProtocolTCP, + }, + { + ContainerPort: aws.Uint16(97), + HostPort: 47002, + BindIP: "0.0.0.0", + Protocol: apicontainer.TransportProtocolTCP, + }, + }, + } + + return testContainerStateChange +} + +func TestGetNetworkBindings(t *testing.T) { + testContainerStateChange := getTestContainerStateChange() + expectedNetworkBindings := []*ecs.NetworkBinding{ + { + BindIP: strptr("0.0.0.0"), + ContainerPort: int64ptr(intptr(10)), + HostPort: int64ptr(intptr(10)), + Protocol: strptr("tcp"), + }, + { + BindIP: strptr("1.2.3.4"), + ContainerPort: int64ptr(intptr(12)), + HostPort: int64ptr(intptr(12)), + Protocol: strptr("udp"), + }, + { + BindIP: strptr("5.6.7.8"), + ContainerPort: int64ptr(intptr(15)), + HostPort: int64ptr(intptr(20)), + Protocol: strptr("tcp"), + }, + { + BindIP: strptr("::"), + ContainerPortRange: strptr("21-22"), + HostPortRange: strptr("60001-60002"), + Protocol: strptr("udp"), + }, + { + BindIP: strptr("0.0.0.0"), + ContainerPortRange: strptr("96-97"), + HostPortRange: strptr("47001-47002"), + Protocol: strptr("tcp"), + }, + } + + networkBindings := getNetworkBindings(testContainerStateChange, false) + assert.ElementsMatch(t, expectedNetworkBindings, networkBindings) +}