diff --git a/agent/acs/session/payload_responder.go b/agent/acs/session/payload_responder.go index 4bb85f2be1e..7c5a61e2cc4 100644 --- a/agent/acs/session/payload_responder.go +++ b/agent/acs/session/payload_responder.go @@ -23,6 +23,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/eventhandler" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" apiresource "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment/resource" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" @@ -40,7 +41,7 @@ type skipAddTaskComparatorFunc func(apitaskstatus.TaskStatus) bool // payloadMessageHandler implements PayloadMessageHandler interface defined in ecs-agent module. type payloadMessageHandler struct { taskEngine engine.TaskEngine - ecsClient api.ECSClient + ecsClient ecs.ECSClient dataClient data.Client taskHandler *eventhandler.TaskHandler credentialsManager credentials.Manager @@ -49,7 +50,7 @@ type payloadMessageHandler struct { // NewPayloadMessageHandler creates a new payloadMessageHandler. func NewPayloadMessageHandler(taskEngine engine.TaskEngine, - ecsClient api.ECSClient, + ecsClient ecs.ECSClient, dataClient data.Client, taskHandler *eventhandler.TaskHandler, credentialsManager credentials.Manager, diff --git a/agent/acs/session/payload_responder_test.go b/agent/acs/session/payload_responder_test.go index 8a4eb6d7bf8..96a9817c71b 100644 --- a/agent/acs/session/payload_responder_test.go +++ b/agent/acs/session/payload_responder_test.go @@ -22,8 +22,6 @@ import ( "sync" "testing" - "github.com/aws/amazon-ecs-agent/agent/api" - mock_api "github.com/aws/amazon-ecs-agent/agent/api/mocks" apitask "github.com/aws/amazon-ecs-agent/agent/api/task" "github.com/aws/amazon-ecs-agent/agent/data" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" @@ -35,6 +33,8 @@ import ( acssession "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst" apiresource "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment/resource" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" + mock_ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks" apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface" @@ -76,7 +76,7 @@ type testHelper struct { func setup(t *testing.T, acsResponseSender wsclient.RespondFunc) *testHelper { ctrl := gomock.NewController(t) taskEngine := mock_engine.NewMockTaskEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) dataClient := data.NewNoopClient() credentialsManager := credentials.NewManager() ctx := context.Background() @@ -1057,7 +1057,7 @@ func TestHandlePayloadMessageAddedFirelensData(t *testing.T) { func TestHandleInvalidTask(t *testing.T) { tester := setup(t, nil) - mockECSACSClient := mock_api.NewMockECSClient(tester.ctrl) + mockECSACSClient := mock_ecs.NewMockECSClient(tester.ctrl) taskHandler := eventhandler.NewTaskHandler(tester.ctx, data.NewNoopClient(), dockerstate.NewTaskEngineState(), mockECSACSClient) tester.payloadMessageHandler.ecsClient = mockECSACSClient @@ -1070,8 +1070,8 @@ func TestHandleInvalidTask(t *testing.T) { wait := &sync.WaitGroup{} wait.Add(1) - mockECSACSClient.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change api.TaskStateChange) { - assert.NotNil(t, change.Task) + mockECSACSClient.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change ecs.TaskStateChange) { + assert.False(t, change.MetadataGetter.GetTaskIsNil()) wait.Done() }) diff --git a/agent/api/ecsclient/client.go b/agent/api/ecsclient/client.go deleted file mode 100644 index 5042394404d..00000000000 --- a/agent/api/ecsclient/client.go +++ /dev/null @@ -1,777 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package ecsclient - -import ( - "context" - "errors" - "fmt" - "runtime" - "strings" - "time" - - "github.com/aws/amazon-ecs-agent/agent/api" - "github.com/aws/amazon-ecs-agent/agent/config" - "github.com/aws/amazon-ecs-agent/agent/utils" - agentversion "github.com/aws/amazon-ecs-agent/agent/version" - apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" - "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" - apierrors "github.com/aws/amazon-ecs-agent/ecs-agent/api/errors" - "github.com/aws/amazon-ecs-agent/ecs-agent/async" - "github.com/aws/amazon-ecs-agent/ecs-agent/credentials/instancecreds" - "github.com/aws/amazon-ecs-agent/ecs-agent/ec2" - "github.com/aws/amazon-ecs-agent/ecs-agent/httpclient" - "github.com/aws/amazon-ecs-agent/ecs-agent/logger" - commonutils "github.com/aws/amazon-ecs-agent/ecs-agent/utils" - "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" - - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/cihub/seelog" - "github.com/docker/docker/pkg/system" - "github.com/docker/go-connections/nat" -) - -const ( - ecsMaxImageDigestLength = 255 - ecsMaxContainerReasonLength = 255 - ecsMaxTaskReasonLength = 1024 - ecsMaxRuntimeIDLength = 255 - pollEndpointCacheTTL = 12 * time.Hour - azAttrName = "ecs.availability-zone" - cpuArchAttrName = "ecs.cpu-architecture" - osTypeAttrName = "ecs.os-type" - osFamilyAttrName = "ecs.os-family" - RoundtripTimeout = 5 * time.Second - // ecsMaxNetworkBindingsLength is the maximum length of the ecs.NetworkBindings list sent as part of the - // container state change payload. Currently, this is enforced only when containerPortRanges are requested. - ecsMaxNetworkBindingsLength = 100 - // Following constants used for SetInstanceIdentity retry with exponential backoff - setInstanceIdRetryTimeOut = 30 * time.Second - setInstanceIdRetryBackoffMin = 100 * time.Millisecond - setInstanceIdRetryBackoffMax = 5 * time.Second - setInstanceIdRetryBackoffJitter = 0.2 - setInstanceIdRetryBackoffMultiple = 2 -) - -// APIECSClient implements ECSClient -type APIECSClient struct { - credentialProvider *credentials.Credentials - config *config.Config - standardClient api.ECSSDK - submitStateChangeClient api.ECSSubmitStateSDK - ec2metadata ec2.EC2MetadataClient - pollEndpointCache async.TTLCache -} - -// NewECSClient creates a new ECSClient interface object -func NewECSClient( - credentialProvider *credentials.Credentials, - cfg *config.Config, - ec2MetadataClient ec2.EC2MetadataClient) api.ECSClient { - - var ecsConfig aws.Config - ecsConfig.Credentials = credentialProvider - ecsConfig.Region = &cfg.AWSRegion - ecsConfig.HTTPClient = httpclient.New(RoundtripTimeout, cfg.AcceptInsecureCert, agentversion.String(), config.OSType) - if cfg.APIEndpoint != "" { - ecsConfig.Endpoint = &cfg.APIEndpoint - } - standardClient := ecs.New(session.New(&ecsConfig)) - submitStateChangeClient := newSubmitStateChangeClient(&ecsConfig) - return &APIECSClient{ - credentialProvider: credentialProvider, - config: cfg, - standardClient: standardClient, - submitStateChangeClient: submitStateChangeClient, - ec2metadata: ec2MetadataClient, - pollEndpointCache: async.NewTTLCache(&async.TTL{Duration: pollEndpointCacheTTL}), - } -} - -// SetSDK overrides the SDK to the given one. This is useful for injecting a -// test implementation -func (client *APIECSClient) SetSDK(sdk api.ECSSDK) { - client.standardClient = sdk -} - -// SetSubmitStateChangeSDK overrides the SDK to the given one. This is useful -// for injecting a test implementation -func (client *APIECSClient) SetSubmitStateChangeSDK(sdk api.ECSSubmitStateSDK) { - client.submitStateChangeClient = sdk -} - -// CreateCluster creates a cluster from a given name and returns its arn -func (client *APIECSClient) CreateCluster(clusterName string) (string, error) { - resp, err := client.standardClient.CreateCluster(&ecs.CreateClusterInput{ClusterName: &clusterName}) - if err != nil { - seelog.Criticalf("Could not create cluster: %v", err) - return "", err - } - seelog.Infof("Created a cluster named: %s", clusterName) - return *resp.Cluster.ClusterName, nil -} - -// RegisterContainerInstance calculates the appropriate resources, creates -// the default cluster if necessary, and returns the registered -// ContainerInstanceARN if successful. Supplying a non-empty container -// instance ARN allows a container instance to update its registered -// resources. -func (client *APIECSClient) RegisterContainerInstance(containerInstanceArn string, attributes []*ecs.Attribute, - tags []*ecs.Tag, registrationToken string, platformDevices []*ecs.PlatformDevice, - outpostARN string) (string, string, error) { - - clusterRef := client.config.Cluster - // If our clusterRef is empty, we should try to create the default - if clusterRef == "" { - clusterRef = config.DefaultClusterName - defer func() { - // Update the config value to reflect the cluster we end up in - client.config.Cluster = clusterRef - }() - // Attempt to register without checking existence of the cluster so we don't require - // excess permissions in the case where the cluster already exists and is active - containerInstanceArn, availabilityzone, err := client.registerContainerInstance(clusterRef, - containerInstanceArn, attributes, tags, registrationToken, platformDevices, outpostARN) - if err == nil { - return containerInstanceArn, availabilityzone, nil - } - - // If trying to register fails because the default cluster doesn't exist, try to create the cluster before calling - // register again - if apierrors.IsClusterNotFoundError(err) { - clusterRef, err = client.CreateCluster(clusterRef) - if err != nil { - return "", "", err - } - } - } - return client.registerContainerInstance(clusterRef, containerInstanceArn, attributes, tags, registrationToken, - platformDevices, outpostARN) -} - -func (client *APIECSClient) registerContainerInstance(clusterRef string, containerInstanceArn string, - attributes []*ecs.Attribute, tags []*ecs.Tag, registrationToken string, - platformDevices []*ecs.PlatformDevice, outpostARN string) (string, string, error) { - - registerRequest := ecs.RegisterContainerInstanceInput{Cluster: &clusterRef} - var registrationAttributes []*ecs.Attribute - if containerInstanceArn != "" { - // We are re-connecting a previously registered instance, restored from snapshot. - registerRequest.ContainerInstanceArn = &containerInstanceArn - } else { - // This is a new instance, not previously registered. - // Custom attribute registration only happens on initial instance registration. - for _, attribute := range client.getCustomAttributes() { - seelog.Debugf("Added a new custom attribute %v=%v", - aws.StringValue(attribute.Name), - aws.StringValue(attribute.Value), - ) - registrationAttributes = append(registrationAttributes, attribute) - } - } - // Standard attributes are included with all registrations. - registrationAttributes = append(registrationAttributes, attributes...) - - // Add additional attributes such as the os type - registrationAttributes = append(registrationAttributes, client.getAdditionalAttributes()...) - registrationAttributes = append(registrationAttributes, client.getOutpostAttribute(outpostARN)...) - - registerRequest.Attributes = registrationAttributes - if len(tags) > 0 { - registerRequest.Tags = tags - } - registerRequest.PlatformDevices = platformDevices - registerRequest = client.setInstanceIdentity(registerRequest) - - resources, err := client.getResources() - if err != nil { - return "", "", err - } - - registerRequest.TotalResources = resources - - registerRequest.ClientToken = ®istrationToken - resp, err := client.standardClient.RegisterContainerInstance(®isterRequest) - if err != nil { - seelog.Errorf("Unable to register as a container instance with ECS: %v", err) - return "", "", err - } - - var availabilityzone = "" - if resp != nil { - for _, attr := range resp.ContainerInstance.Attributes { - if aws.StringValue(attr.Name) == azAttrName { - availabilityzone = aws.StringValue(attr.Value) - break - } - } - } - - seelog.Info("Registered container instance with cluster!") - err = validateRegisteredAttributes(registerRequest.Attributes, resp.ContainerInstance.Attributes) - return aws.StringValue(resp.ContainerInstance.ContainerInstanceArn), availabilityzone, err -} - -func (client *APIECSClient) setInstanceIdentity(registerRequest ecs.RegisterContainerInstanceInput) ecs.RegisterContainerInstanceInput { - instanceIdentityDoc := "" - instanceIdentitySignature := "" - - if client.config.NoIID { - seelog.Info("Fetching Instance ID Document has been disabled") - registerRequest.InstanceIdentityDocument = &instanceIdentityDoc - registerRequest.InstanceIdentityDocumentSignature = &instanceIdentitySignature - return registerRequest - } - - iidRetrieved := true - backoff := retry.NewExponentialBackoff(setInstanceIdRetryBackoffMin, setInstanceIdRetryBackoffMax, - setInstanceIdRetryBackoffJitter, setInstanceIdRetryBackoffMultiple) - ctx, cancel := context.WithTimeout(context.Background(), setInstanceIdRetryTimeOut) - defer cancel() - err := retry.RetryWithBackoffCtx(ctx, backoff, func() error { - var attemptErr error - seelog.Debugf("Attempting to get Instance Identity Document") - instanceIdentityDoc, attemptErr = client.ec2metadata.GetDynamicData(ec2.InstanceIdentityDocumentResource) - if attemptErr != nil { - seelog.Debugf("Unable to get instance identity document, retrying: %v", attemptErr) - // force credentials to expire in case they are stale but not expired - client.credentialProvider.Expire() - client.credentialProvider = instancecreds.GetCredentials(client.config.External.Enabled()) - return apierrors.NewRetriableError(apierrors.NewRetriable(true), attemptErr) - } - seelog.Debugf("Successfully retrieved Instance Identity Document") - return nil - }) - if err != nil { - seelog.Errorf("Unable to get instance identity document: %v", err) - iidRetrieved = false - } - registerRequest.InstanceIdentityDocument = &instanceIdentityDoc - - if iidRetrieved { - instanceIdentitySignature, err = client.ec2metadata.GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource) - if err != nil { - seelog.Errorf("Unable to get instance identity signature: %v", err) - } - } - - registerRequest.InstanceIdentityDocumentSignature = &instanceIdentitySignature - return registerRequest -} - -func attributesToMap(attributes []*ecs.Attribute) map[string]string { - attributeMap := make(map[string]string) - attribs := attributes - for _, attribute := range attribs { - attributeMap[aws.StringValue(attribute.Name)] = aws.StringValue(attribute.Value) - } - return attributeMap -} - -func findMissingAttributes(expectedAttributes, actualAttributes map[string]string) ([]string, error) { - missingAttributes := make([]string, 0) - var err error - for key, val := range expectedAttributes { - if actualAttributes[key] != val { - missingAttributes = append(missingAttributes, key) - } else { - seelog.Tracef("Response contained expected value for attribute %v", key) - } - } - if len(missingAttributes) > 0 { - err = apierrors.NewAttributeError("Attribute validation failed") - } - return missingAttributes, err -} - -func (client *APIECSClient) getResources() ([]*ecs.Resource, error) { - // Micro-optimization, the pointer to this is used multiple times below - integerStr := "INTEGER" - - cpu, mem := getCpuAndMemory() - remainingMem := mem - int64(client.config.ReservedMemory) - seelog.Infof("Remaining mem: %d", remainingMem) - if remainingMem < 0 { - return nil, fmt.Errorf( - "api register-container-instance: reserved memory is higher than available memory on the host, total memory: %d, reserved: %d", - mem, client.config.ReservedMemory) - } - - cpuResource := ecs.Resource{ - Name: utils.Strptr("CPU"), - Type: &integerStr, - IntegerValue: &cpu, - } - memResource := ecs.Resource{ - Name: utils.Strptr("MEMORY"), - Type: &integerStr, - IntegerValue: &remainingMem, - } - portResource := ecs.Resource{ - Name: utils.Strptr("PORTS"), - Type: utils.Strptr("STRINGSET"), - StringSetValue: commonutils.Uint16SliceToStringSlice(client.config.ReservedPorts), - } - udpPortResource := ecs.Resource{ - Name: utils.Strptr("PORTS_UDP"), - Type: utils.Strptr("STRINGSET"), - StringSetValue: commonutils.Uint16SliceToStringSlice(client.config.ReservedPortsUDP), - } - - return []*ecs.Resource{&cpuResource, &memResource, &portResource, &udpPortResource}, nil -} - -// GetHostResources calling getHostResources to get a list of CPU, MEMORY, PORTS and PORTS_UPD resources -// and return a resourceMap that map the resource name to each resource -func (client *APIECSClient) GetHostResources() (map[string]*ecs.Resource, error) { - resources, err := client.getResources() - if err != nil { - return nil, err - } - resourceMap := make(map[string]*ecs.Resource) - for _, resource := range resources { - if *resource.Name == "PORTS" { - // Except for RCI, TCP Ports are named as PORTS_TCP in agent for Host Resources purpose - resource.Name = utils.Strptr("PORTS_TCP") - } - resourceMap[*resource.Name] = resource - } - return resourceMap, nil -} - -func getCpuAndMemory() (int64, int64) { - memInfo, err := system.ReadMemInfo() - mem := int64(0) - if err == nil { - mem = memInfo.MemTotal / 1024 / 1024 // MiB - } else { - seelog.Errorf("Unable to get memory info: %v", err) - } - - cpu := runtime.NumCPU() * 1024 - - return int64(cpu), mem -} - -func validateRegisteredAttributes(expectedAttributes, actualAttributes []*ecs.Attribute) error { - var err error - expectedAttributesMap := attributesToMap(expectedAttributes) - actualAttributesMap := attributesToMap(actualAttributes) - missingAttributes, err := findMissingAttributes(expectedAttributesMap, actualAttributesMap) - if err != nil { - msg := strings.Join(missingAttributes, ",") - seelog.Errorf("Error registering attributes: %v", msg) - } - return err -} - -func (client *APIECSClient) getAdditionalAttributes() []*ecs.Attribute { - attrs := []*ecs.Attribute{ - { - Name: aws.String(osTypeAttrName), - Value: aws.String(config.OSType), - }, - { - Name: aws.String(osFamilyAttrName), - Value: aws.String(config.GetOSFamily()), - }, - } - // Send cpu arch attribute directly when running on external capacity. When running on EC2, this is not needed - // since the cpu arch is reported via instance identity doc in that case. - if client.config.External.Enabled() { - attrs = append(attrs, &ecs.Attribute{ - Name: aws.String(cpuArchAttrName), - Value: aws.String(getCPUArch()), - }) - } - return attrs -} - -func (client *APIECSClient) getOutpostAttribute(outpostARN string) []*ecs.Attribute { - if len(outpostARN) > 0 { - return []*ecs.Attribute{ - { - Name: aws.String("ecs.outpost-arn"), - Value: aws.String(outpostARN), - }, - } - } - return []*ecs.Attribute{} -} - -func (client *APIECSClient) getCustomAttributes() []*ecs.Attribute { - var attributes []*ecs.Attribute - for attribute, value := range client.config.InstanceAttributes { - attributes = append(attributes, &ecs.Attribute{ - Name: aws.String(attribute), - Value: aws.String(value), - }) - } - return attributes -} - -func (client *APIECSClient) SubmitTaskStateChange(change api.TaskStateChange) error { - // Submit attachment state change - if change.Attachment != nil { - var attachments []*ecs.AttachmentStateChange - - eniStatus := change.Attachment.Status.String() - attachments = []*ecs.AttachmentStateChange{ - { - AttachmentArn: aws.String(change.Attachment.AttachmentARN), - Status: aws.String(eniStatus), - }, - } - - _, err := client.submitStateChangeClient.SubmitTaskStateChange(&ecs.SubmitTaskStateChangeInput{ - Cluster: aws.String(client.config.Cluster), - Task: aws.String(change.TaskARN), - Attachments: attachments, - }) - if err != nil { - seelog.Warnf("Could not submit an attachment state change: %v", err) - return err - } - - return nil - } - - status := change.Status.BackendStatus() - - req := ecs.SubmitTaskStateChangeInput{ - Cluster: aws.String(client.config.Cluster), - Task: aws.String(change.TaskARN), - Status: aws.String(status), - Reason: aws.String(trimString(change.Reason, ecsMaxTaskReasonLength)), - PullStartedAt: change.PullStartedAt, - PullStoppedAt: change.PullStoppedAt, - ExecutionStoppedAt: change.ExecutionStoppedAt, - } - - for _, managedAgentEvent := range change.ManagedAgents { - if mgspl := client.buildManagedAgentStateChangePayload(managedAgentEvent); mgspl != nil { - req.ManagedAgents = append(req.ManagedAgents, mgspl) - } - } - - containerEvents := make([]*ecs.ContainerStateChange, len(change.Containers)) - for i, containerEvent := range change.Containers { - payload, err := client.buildContainerStateChangePayload(containerEvent, client.config.ShouldExcludeIPv6PortBinding.Enabled()) - if err != nil { - seelog.Errorf("Could not submit task state change: [%s]: %v", change.String(), err) - return err - } - containerEvents[i] = payload - } - - req.Containers = containerEvents - - _, err := client.submitStateChangeClient.SubmitTaskStateChange(&req) - if err != nil { - seelog.Warnf("Could not submit task state change: [%s]: %v", change.String(), err) - return err - } - - return nil -} - -func trimString(inputString string, maxLen int) string { - if len(inputString) > maxLen { - trimmed := inputString[0:maxLen] - return trimmed - } else { - return inputString - } -} - -func (client *APIECSClient) buildManagedAgentStateChangePayload(change api.ManagedAgentStateChange) *ecs.ManagedAgentStateChange { - if !change.Status.ShouldReportToBackend() { - seelog.Warnf("Not submitting unsupported managed agent state %s for container %s in task %s", - change.Status.String(), change.Container.Name, change.TaskArn) - return nil - } - var trimmedReason *string - if change.Reason != "" { - trimmedReason = aws.String(trimString(change.Reason, ecsMaxContainerReasonLength)) - } - return &ecs.ManagedAgentStateChange{ - ManagedAgentName: aws.String(change.Name), - ContainerName: aws.String(change.Container.Name), - Status: aws.String(change.Status.String()), - Reason: trimmedReason, - } -} - -func (client *APIECSClient) buildContainerStateChangePayload(change api.ContainerStateChange, shouldExcludeIPv6PortBinding bool) (*ecs.ContainerStateChange, error) { - statechange := &ecs.ContainerStateChange{ - ContainerName: aws.String(change.ContainerName), - } - if change.RuntimeID != "" { - trimmedRuntimeID := trimString(change.RuntimeID, ecsMaxRuntimeIDLength) - statechange.RuntimeId = aws.String(trimmedRuntimeID) - } - if change.Reason != "" { - trimmedReason := trimString(change.Reason, ecsMaxContainerReasonLength) - statechange.Reason = aws.String(trimmedReason) - } - if change.ImageDigest != "" { - trimmedImageDigest := trimString(change.ImageDigest, ecsMaxImageDigestLength) - statechange.ImageDigest = aws.String(trimmedImageDigest) - } - status := change.Status - - if status != apicontainerstatus.ContainerStopped && status != apicontainerstatus.ContainerRunning { - seelog.Warnf("Not submitting unsupported upstream container state %s for container %s in task %s", - status.String(), change.ContainerName, change.TaskArn) - return nil, nil - } - stat := change.Status.String() - if stat == "DEAD" { - stat = apicontainerstatus.ContainerStopped.String() - } - statechange.Status = aws.String(stat) - - if change.ExitCode != nil { - exitCode := int64(aws.IntValue(change.ExitCode)) - statechange.ExitCode = aws.Int64(exitCode) - } - - networkBindings := getNetworkBindings(change, shouldExcludeIPv6PortBinding) - // we enforce a limit on the no. of network bindings for containers with at-least 1 port range requested. - // this limit is enforced by ECS, and we fail early and don't call SubmitContainerStateChange. - if change.Container.HasPortRange() && len(networkBindings) > ecsMaxNetworkBindingsLength { - return nil, fmt.Errorf("no. of network bindings %v is more than the maximum supported no. %v, "+ - "container: %s "+"task: %s", len(networkBindings), ecsMaxNetworkBindingsLength, change.ContainerName, change.TaskArn) - } - statechange.NetworkBindings = networkBindings - - return statechange, nil -} - -// ProtocolBindIP used to store protocol and bindIP information associated to a particular host port -type ProtocolBindIP struct { - protocol string - bindIP string -} - -// getNetworkBindings returns the list of networkingBindings, sent to ECS as part of the container state change payload -func getNetworkBindings(change api.ContainerStateChange, shouldExcludeIPv6PortBinding bool) []*ecs.NetworkBinding { - networkBindings := []*ecs.NetworkBinding{} - // hostPortToProtocolBindIPMap is a map to store protocol and bindIP information associated to host ports - // that belong to a range. This is used in case when there are multiple protocol/bindIP combinations associated to a - // port binding. example: when both IPv4 and IPv6 bindIPs are populated by docker and shouldExcludeIPv6PortBinding is false. - hostPortToProtocolBindIPMap := map[int64][]ProtocolBindIP{} - - // ContainerPortSet consists of singular ports, and ports that belong to a range, but for which we were not able to - // find contiguous host ports and ask docker to pick instead. - containerPortSet := change.Container.GetContainerPortSet() - // each entry in the ContainerPortRangeMap implies that we found a contiguous host port range for the same - containerPortRangeMap := change.Container.GetContainerPortRangeMap() - - for _, binding := range change.PortBindings { - if binding.BindIP == "::" && shouldExcludeIPv6PortBinding { - seelog.Debugf("Exclude IPv6 port binding %v for container %s in task %s", binding, change.ContainerName, change.TaskArn) - continue - } - - hostPort := int64(binding.HostPort) - containerPort := int64(binding.ContainerPort) - bindIP := binding.BindIP - protocol := binding.Protocol.String() - - // create network binding for each containerPort that exists in the singular ContainerPortSet - // for container ports that belong to a range, we'll have 1 consolidated network binding for the range - if _, ok := containerPortSet[int(containerPort)]; ok { - networkBindings = append(networkBindings, &ecs.NetworkBinding{ - BindIP: aws.String(bindIP), - ContainerPort: aws.Int64(containerPort), - HostPort: aws.Int64(hostPort), - Protocol: aws.String(protocol), - }) - } else { - // populate hostPortToProtocolBindIPMap – this is used below when we construct network binding for ranges. - hostPortToProtocolBindIPMap[hostPort] = append(hostPortToProtocolBindIPMap[hostPort], - ProtocolBindIP{ - protocol: protocol, - bindIP: bindIP, - }) - } - } - - for containerPortRange, hostPortRange := range containerPortRangeMap { - // we check for protocol and bindIP information associated to any one of the host ports from the hostPortRange, - // all ports belonging to the same range share this information. - hostPort, _, _ := nat.ParsePortRangeToInt(hostPortRange) - if val, ok := hostPortToProtocolBindIPMap[int64(hostPort)]; ok { - for _, v := range val { - networkBindings = append(networkBindings, &ecs.NetworkBinding{ - BindIP: aws.String(v.bindIP), - ContainerPortRange: aws.String(containerPortRange), - HostPortRange: aws.String(hostPortRange), - Protocol: aws.String(v.protocol), - }) - } - } - } - - return networkBindings -} - -func (client *APIECSClient) SubmitContainerStateChange(change api.ContainerStateChange) error { - pl, err := client.buildContainerStateChangePayload(change, client.config.ShouldExcludeIPv6PortBinding.Enabled()) - if err != nil { - seelog.Errorf("Could not build container state change payload: [%s]: %v", change.String(), err) - return err - } else if pl == nil { - return nil - } - - _, err = client.submitStateChangeClient.SubmitContainerStateChange(&ecs.SubmitContainerStateChangeInput{ - Cluster: aws.String(client.config.Cluster), - ContainerName: aws.String(change.ContainerName), - ExitCode: pl.ExitCode, - ManagedAgents: pl.ManagedAgents, - NetworkBindings: pl.NetworkBindings, - Reason: pl.Reason, - RuntimeId: pl.RuntimeId, - Status: pl.Status, - Task: aws.String(change.TaskArn), - }) - if err != nil { - seelog.Warnf("Could not submit container state change: [%s]: %v", change.String(), err) - return err - } - return nil -} - -func (client *APIECSClient) SubmitAttachmentStateChange(change api.AttachmentStateChange) error { - attachmentStatus := change.Attachment.GetAttachmentStatus() - - req := ecs.SubmitAttachmentStateChangesInput{ - Cluster: &client.config.Cluster, - Attachments: []*ecs.AttachmentStateChange{ - { - AttachmentArn: aws.String(change.Attachment.GetAttachmentARN()), - Status: aws.String(attachmentStatus.String()), - }, - }, - } - - _, err := client.submitStateChangeClient.SubmitAttachmentStateChanges(&req) - if err != nil { - seelog.Warnf("Could not submit attachment state change [%s]: %v", change.String(), err) - return err - } - - return nil -} - -func (client *APIECSClient) DiscoverPollEndpoint(containerInstanceArn string) (string, error) { - resp, err := client.discoverPollEndpoint(containerInstanceArn) - if err != nil { - return "", err - } - - return aws.StringValue(resp.Endpoint), nil -} - -func (client *APIECSClient) DiscoverTelemetryEndpoint(containerInstanceArn string) (string, error) { - resp, err := client.discoverPollEndpoint(containerInstanceArn) - if err != nil { - return "", err - } - if resp.TelemetryEndpoint == nil { - return "", errors.New("No telemetry endpoint returned; nil") - } - - return aws.StringValue(resp.TelemetryEndpoint), nil -} - -func (client *APIECSClient) DiscoverServiceConnectEndpoint(containerInstanceArn string) (string, error) { - resp, err := client.discoverPollEndpoint(containerInstanceArn) - if err != nil { - return "", err - } - if resp.ServiceConnectEndpoint == nil { - return "", errors.New("No ServiceConnect endpoint returned; nil") - } - - return aws.StringValue(resp.ServiceConnectEndpoint), nil -} - -func (client *APIECSClient) discoverPollEndpoint(containerInstanceArn string) (*ecs.DiscoverPollEndpointOutput, error) { - // Try getting an entry from the cache - cachedEndpoint, expired, found := client.pollEndpointCache.Get(containerInstanceArn) - if !expired && found { - // Cache hit and not expired. Return the output. - if output, ok := cachedEndpoint.(*ecs.DiscoverPollEndpointOutput); ok { - logger.Info("Using cached DiscoverPollEndpoint", logger.Fields{ - "endpoint": aws.StringValue(output.Endpoint), - "telemetryEndpoint": aws.StringValue(output.TelemetryEndpoint), - "serviceConnectEndpoint": aws.StringValue(output.ServiceConnectEndpoint), - "containerInstanceARN": containerInstanceArn, - }) - return output, nil - } - } - - // Cache miss or expired, invoke the ECS DiscoverPollEndpoint API. - seelog.Debugf("Invoking DiscoverPollEndpoint for '%s'", containerInstanceArn) - output, err := client.standardClient.DiscoverPollEndpoint(&ecs.DiscoverPollEndpointInput{ - ContainerInstance: &containerInstanceArn, - Cluster: &client.config.Cluster, - }) - if err != nil { - // if we got an error calling the API, fallback to an expired cached endpoint if - // we have it. - if expired { - if output, ok := cachedEndpoint.(*ecs.DiscoverPollEndpointOutput); ok { - logger.Info("Error calling DiscoverPollEndpoint. Using cached-but-expired endpoint as a fallback.", logger.Fields{ - "endpoint": aws.StringValue(output.Endpoint), - "telemetryEndpoint": aws.StringValue(output.TelemetryEndpoint), - "serviceConnectEndpoint": aws.StringValue(output.ServiceConnectEndpoint), - "containerInstanceARN": containerInstanceArn, - }) - return output, nil - } - } - return nil, err - } - - // Cache the response from ECS. - client.pollEndpointCache.Set(containerInstanceArn, output) - return output, nil -} - -func (client *APIECSClient) GetResourceTags(resourceArn string) ([]*ecs.Tag, error) { - output, err := client.standardClient.ListTagsForResource(&ecs.ListTagsForResourceInput{ - ResourceArn: &resourceArn, - }) - if err != nil { - return nil, err - } - return output.Tags, nil -} - -func (client *APIECSClient) UpdateContainerInstancesState(instanceARN string, status string) error { - seelog.Debugf("Invoking UpdateContainerInstancesState, status='%s' instanceARN='%s'", status, instanceARN) - _, err := client.standardClient.UpdateContainerInstancesState(&ecs.UpdateContainerInstancesStateInput{ - ContainerInstances: []*string{aws.String(instanceARN)}, - Status: aws.String(status), - Cluster: &client.config.Cluster, - }) - return err -} diff --git a/agent/api/ecsclient/client_test.go b/agent/api/ecsclient/client_test.go deleted file mode 100644 index d9b54539afd..00000000000 --- a/agent/api/ecsclient/client_test.go +++ /dev/null @@ -1,1340 +0,0 @@ -//go:build unit -// +build unit - -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package ecsclient - -import ( - "errors" - "fmt" - "reflect" - "strings" - "testing" - "time" - - "github.com/golang/mock/gomock" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/aws/amazon-ecs-agent/agent/api" - apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" - mock_api "github.com/aws/amazon-ecs-agent/agent/api/mocks" - "github.com/aws/amazon-ecs-agent/agent/config" - "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment" - apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" - "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" - apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" - "github.com/aws/amazon-ecs-agent/ecs-agent/async" - mock_async "github.com/aws/amazon-ecs-agent/ecs-agent/async/mocks" - "github.com/aws/amazon-ecs-agent/ecs-agent/ec2" - mock_ec2 "github.com/aws/amazon-ecs-agent/ecs-agent/ec2/mocks" - ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/credentials" -) - -const ( - configuredCluster = "mycluster" - iid = "instanceIdentityDocument" - iidSignature = "signature" - registrationToken = "clientToken" - testNetworkName = "bridge" -) - -var ( - iidResponse = []byte(iid) - iidSignatureResponse = []byte(iidSignature) - containerInstanceTags = []*ecs.Tag{ - { - Key: aws.String("my_key1"), - Value: aws.String("my_val1"), - }, - { - Key: aws.String("my_key2"), - Value: aws.String("my_val2"), - }, - } - containerInstanceTagsMap = map[string]string{ - "my_key1": "my_val1", - "my_key2": "my_val2", - } - testManagedAgents = []*ecs.ManagedAgentStateChange{ - { - ManagedAgentName: aws.String("test_managed_agent"), - ContainerName: aws.String("test_container"), - Status: aws.String("RUNNING"), - Reason: aws.String("test_reason"), - }, - } -) - -func NewMockClient(ctrl *gomock.Controller, - ec2Metadata ec2.EC2MetadataClient, - additionalAttributes map[string]string) (api.ECSClient, *mock_api.MockECSSDK, *mock_api.MockECSSubmitStateSDK) { - - return NewMockClientWithConfig(ctrl, ec2Metadata, additionalAttributes, - &config.Config{ - Cluster: configuredCluster, - AWSRegion: "us-east-1", - InstanceAttributes: additionalAttributes, - ShouldExcludeIPv6PortBinding: config.BooleanDefaultTrue{Value: config.ExplicitlyEnabled}, - }) -} - -func NewMockClientWithConfig(ctrl *gomock.Controller, - ec2Metadata ec2.EC2MetadataClient, - additionalAttributes map[string]string, - cfg *config.Config) (api.ECSClient, *mock_api.MockECSSDK, *mock_api.MockECSSubmitStateSDK) { - client := NewECSClient(credentials.AnonymousCredentials, cfg, ec2Metadata) - mockSDK := mock_api.NewMockECSSDK(ctrl) - mockSubmitStateSDK := mock_api.NewMockECSSubmitStateSDK(ctrl) - client.(*APIECSClient).SetSDK(mockSDK) - client.(*APIECSClient).SetSubmitStateChangeSDK(mockSubmitStateSDK) - return client, mockSDK, mockSubmitStateSDK -} - -type containerSubmitInputMatcher struct { - ecs.SubmitContainerStateChangeInput -} - -type taskSubmitInputMatcher struct { - ecs.SubmitTaskStateChangeInput -} - -func strptr(s string) *string { return &s } -func intptr(i int) *int { return &i } -func int64ptr(i *int) *int64 { - if i == nil { - return nil - } - j := int64(*i) - return &j -} -func equal(lhs, rhs interface{}) bool { - return reflect.DeepEqual(lhs, rhs) -} -func (lhs *containerSubmitInputMatcher) Matches(x interface{}) bool { - rhs := x.(*ecs.SubmitContainerStateChangeInput) - - return (equal(lhs.Cluster, rhs.Cluster) && - equal(lhs.ContainerName, rhs.ContainerName) && - equal(lhs.ExitCode, rhs.ExitCode) && - equal(lhs.NetworkBindings, rhs.NetworkBindings) && - equal(lhs.ManagedAgents, rhs.ManagedAgents) && - equal(lhs.Reason, rhs.Reason) && - equal(lhs.Status, rhs.Status) && - equal(lhs.Task, rhs.Task)) -} - -func (lhs *containerSubmitInputMatcher) String() string { - return fmt.Sprintf("%+v", *lhs) -} - -func (lhs *taskSubmitInputMatcher) Matches(x interface{}) bool { - rhs := x.(*ecs.SubmitTaskStateChangeInput) - - if !(equal(lhs.Cluster, rhs.Cluster) && - equal(lhs.Task, rhs.Task) && - equal(lhs.Status, rhs.Status) && - equal(lhs.Reason, rhs.Reason) && - equal(len(lhs.Attachments), len(rhs.Attachments))) { - return false - } - - if len(lhs.Attachments) != 0 { - for i := range lhs.Attachments { - if !(equal(lhs.Attachments[i].Status, rhs.Attachments[i].Status) && - equal(lhs.Attachments[i].AttachmentArn, rhs.Attachments[i].AttachmentArn)) { - return false - } - } - } - - if len(lhs.Containers) != 0 && !equal(lhs.Containers, rhs.Containers) { - return false - } - - return true -} - -func (lhs *taskSubmitInputMatcher) String() string { - return fmt.Sprintf("%+v", *lhs) -} - -func TestSubmitContainerStateChange(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - client, _, mockSubmitStateClient := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - mockSubmitStateClient.EXPECT().SubmitContainerStateChange(&containerSubmitInputMatcher{ - ecs.SubmitContainerStateChangeInput{ - Cluster: strptr(configuredCluster), - Task: strptr("arn"), - ContainerName: strptr("cont"), - RuntimeId: strptr("runtime id"), - Status: strptr("RUNNING"), - NetworkBindings: []*ecs.NetworkBinding{ - { - BindIP: strptr("1.2.3.4"), - ContainerPort: int64ptr(intptr(1)), - HostPort: int64ptr(intptr(2)), - Protocol: strptr("tcp"), - }, - { - BindIP: strptr("2.2.3.4"), - ContainerPort: int64ptr(intptr(3)), - HostPort: int64ptr(intptr(4)), - Protocol: strptr("udp"), - }, - { - BindIP: strptr("5.6.7.8"), - ContainerPortRange: strptr("11-12"), - HostPortRange: strptr("11-12"), - Protocol: strptr("udp"), - }, - }, - }, - }) - err := client.SubmitContainerStateChange(api.ContainerStateChange{ - TaskArn: "arn", - ContainerName: "cont", - RuntimeID: "runtime id", - Status: apicontainerstatus.ContainerRunning, - Container: &apicontainer.Container{ - ContainerArn: "arn", - NetworkModeUnsafe: testNetworkName, - ContainerHasPortRange: true, - ContainerPortSet: map[int]struct{}{ - 1: {}, - 3: {}, - }, - ContainerPortRangeMap: map[string]string{ - "11-12": "11-12", - }, - }, - PortBindings: []apicontainer.PortBinding{ - { - BindIP: "1.2.3.4", - ContainerPort: 1, - HostPort: 2, - }, - { - BindIP: "2.2.3.4", - ContainerPort: 3, - HostPort: 4, - Protocol: apicontainer.TransportProtocolUDP, - }, - { - BindIP: "5.6.7.8", - ContainerPort: 11, - HostPort: 11, - Protocol: apicontainer.TransportProtocolUDP, - }, - { - BindIP: "5.6.7.8", - ContainerPort: 12, - HostPort: 12, - Protocol: apicontainer.TransportProtocolUDP, - }, - }, - }) - if err != nil { - t.Errorf("Unable to submit container state change: %v", err) - } -} - -func TestSubmitContainerStateChangeFull(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - client, _, mockSubmitStateClient := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - exitCode := 20 - reason := "I exited" - - mockSubmitStateClient.EXPECT().SubmitContainerStateChange(&containerSubmitInputMatcher{ - ecs.SubmitContainerStateChangeInput{ - Cluster: strptr(configuredCluster), - Task: strptr("arn"), - ContainerName: strptr("cont"), - RuntimeId: strptr("runtime id"), - Status: strptr("STOPPED"), - ExitCode: int64ptr(&exitCode), - Reason: strptr(reason), - NetworkBindings: []*ecs.NetworkBinding{}, - }, - }) - err := client.SubmitContainerStateChange(api.ContainerStateChange{ - TaskArn: "arn", - ContainerName: "cont", - RuntimeID: "runtime id", - Status: apicontainerstatus.ContainerStopped, - ExitCode: &exitCode, - Reason: reason, - Container: &apicontainer.Container{ - NetworkModeUnsafe: testNetworkName, - }, - PortBindings: []apicontainer.PortBinding{ - {}, - }, - }) - if err != nil { - t.Errorf("Unable to submit container state change: %v", err) - } -} - -func TestSubmitContainerStateChangeReason(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - client, _, mockSubmitStateClient := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - exitCode := 20 - reason := strings.Repeat("a", ecsMaxContainerReasonLength) - - mockSubmitStateClient.EXPECT().SubmitContainerStateChange(&containerSubmitInputMatcher{ - ecs.SubmitContainerStateChangeInput{ - Cluster: strptr(configuredCluster), - Task: strptr("arn"), - ContainerName: strptr("cont"), - Status: strptr("STOPPED"), - ExitCode: int64ptr(&exitCode), - Reason: strptr(reason), - NetworkBindings: []*ecs.NetworkBinding{}, - }, - }) - err := client.SubmitContainerStateChange(api.ContainerStateChange{ - TaskArn: "arn", - ContainerName: "cont", - Container: &apicontainer.Container{ - NetworkModeUnsafe: testNetworkName, - }, - Status: apicontainerstatus.ContainerStopped, - ExitCode: &exitCode, - Reason: reason, - }) - if err != nil { - t.Fatal(err) - } -} - -func TestSubmitContainerStateChangeLongReason(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - client, _, mockSubmitStateClient := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - exitCode := 20 - trimmedReason := strings.Repeat("a", ecsMaxContainerReasonLength) - reason := strings.Repeat("a", ecsMaxContainerReasonLength+1) - - mockSubmitStateClient.EXPECT().SubmitContainerStateChange(&containerSubmitInputMatcher{ - ecs.SubmitContainerStateChangeInput{ - Cluster: strptr(configuredCluster), - Task: strptr("arn"), - ContainerName: strptr("cont"), - Status: strptr("STOPPED"), - ExitCode: int64ptr(&exitCode), - Reason: strptr(trimmedReason), - NetworkBindings: []*ecs.NetworkBinding{}, - }, - }) - err := client.SubmitContainerStateChange(api.ContainerStateChange{ - TaskArn: "arn", - ContainerName: "cont", - Container: &apicontainer.Container{ - NetworkModeUnsafe: testNetworkName, - }, - Status: apicontainerstatus.ContainerStopped, - ExitCode: &exitCode, - Reason: reason, - }) - if err != nil { - t.Errorf("Unable to submit container state change: %v", err) - } -} - -func buildAttributeList(capabilities []string, attributes map[string]string) []*ecs.Attribute { - var rv []*ecs.Attribute - for _, capability := range capabilities { - rv = append(rv, &ecs.Attribute{Name: aws.String(capability)}) - } - for key, value := range attributes { - rv = append(rv, &ecs.Attribute{Name: aws.String(key), Value: aws.String(value)}) - } - return rv -} - -func TestRegisterContainerInstance(t *testing.T) { - testCases := []struct { - name string - cfg *config.Config - }{ - { - name: "retry GetDynamicData", - cfg: &config.Config{ - Cluster: configuredCluster, - AWSRegion: "us-west-2", - }, - }, - { - name: "basic case", - cfg: &config.Config{ - Cluster: configuredCluster, - AWSRegion: "us-west-2", - }, - }, - { - name: "no instance identity doc", - cfg: &config.Config{ - Cluster: configuredCluster, - AWSRegion: "us-west-2", - NoIID: true, - }, - }, - { - name: "on prem", - cfg: &config.Config{ - Cluster: configuredCluster, - AWSRegion: "us-west-2", - NoIID: true, - External: config.BooleanDefaultFalse{Value: config.ExplicitlyEnabled}, - }, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(mockCtrl) - additionalAttributes := map[string]string{"my_custom_attribute": "Custom_Value1", - "my_other_custom_attribute": "Custom_Value2", - } - tc.cfg.InstanceAttributes = additionalAttributes - client, mc, _ := NewMockClientWithConfig(mockCtrl, mockEC2Metadata, additionalAttributes, tc.cfg) - - fakeCapabilities := []string{"capability1", "capability2"} - expectedAttributes := map[string]string{ - "ecs.os-type": config.OSType, - "ecs.os-family": config.GetOSFamily(), - "my_custom_attribute": "Custom_Value1", - "my_other_custom_attribute": "Custom_Value2", - "ecs.availability-zone": "us-west-2b", - "ecs.outpost-arn": "test:arn:outpost", - cpuArchAttrName: getCPUArch(), - } - capabilities := buildAttributeList(fakeCapabilities, nil) - platformDevices := []*ecs.PlatformDevice{ - { - Id: aws.String("id1"), - Type: aws.String(ecs.PlatformDeviceTypeGpu), - }, - { - Id: aws.String("id2"), - Type: aws.String(ecs.PlatformDeviceTypeGpu), - }, - { - Id: aws.String("id3"), - Type: aws.String(ecs.PlatformDeviceTypeGpu), - }, - } - - expectedIID := iid - expectedIIDSig := iidSignature - if tc.cfg.NoIID { - expectedIID = "" - expectedIIDSig = "" - } else if tc.name == "retry GetDynamicData" { - gomock.InOrder( - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentResource).Return("", errors.New("fake unit test error")), - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentResource).Return(expectedIID, nil), - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource).Return(expectedIIDSig, nil), - ) - } else { - //basic case - gomock.InOrder( - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentResource).Return(expectedIID, nil), - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource).Return(expectedIIDSig, nil), - ) - } - - var expectedNumOfAttributes int - if !tc.cfg.External.Enabled() { - // 2 capability attributes: capability1, capability2 - // and 5 other attributes: ecs.os-type, ecs.os-family, ecs.outpost-arn, my_custom_attribute, my_other_custom_attribute. - expectedNumOfAttributes = 7 - } else { - // One more attribute for external case: ecs.cpu-architecture - expectedNumOfAttributes = 8 - } - - gomock.InOrder( - mc.EXPECT().RegisterContainerInstance(gomock.Any()).Do(func(req *ecs.RegisterContainerInstanceInput) { - assert.Nil(t, req.ContainerInstanceArn) - assert.Equal(t, configuredCluster, *req.Cluster, "Wrong cluster") - assert.Equal(t, registrationToken, *req.ClientToken, "Wrong client token") - assert.Equal(t, expectedIID, *req.InstanceIdentityDocument, "Wrong IID") - assert.Equal(t, expectedIIDSig, *req.InstanceIdentityDocumentSignature, "Wrong IID sig") - assert.Equal(t, 4, len(req.TotalResources), "Wrong length of TotalResources") - resource, ok := findResource(req.TotalResources, "PORTS_UDP") - require.True(t, ok, `Could not find resource "PORTS_UDP"`) - assert.Equal(t, "STRINGSET", *resource.Type, `Wrong type for resource "PORTS_UDP"`) - assert.Equal(t, expectedNumOfAttributes, len(req.Attributes), "Wrong number of Attributes") - attrs := attributesToMap(req.Attributes) - for name, value := range attrs { - if strings.Contains(name, "capability") { - assert.Contains(t, fakeCapabilities, name) - } else { - assert.Equal(t, expectedAttributes[name], value) - } - } - assert.Equal(t, len(containerInstanceTags), len(req.Tags), "Wrong number of tags") - assert.Equal(t, len(platformDevices), len(req.PlatformDevices), "Wrong number of devices") - reqTags := extractTagsMapFromRegisterContainerInstanceInput(req) - for k, v := range reqTags { - assert.Contains(t, containerInstanceTagsMap, k) - assert.Equal(t, containerInstanceTagsMap[k], v) - } - }).Return(&ecs.RegisterContainerInstanceOutput{ - ContainerInstance: &ecs.ContainerInstance{ - ContainerInstanceArn: aws.String("registerArn"), - Attributes: buildAttributeList(fakeCapabilities, expectedAttributes)}}, - nil), - ) - - arn, availabilityzone, err := client.RegisterContainerInstance("", capabilities, - containerInstanceTags, registrationToken, platformDevices, "test:arn:outpost") - require.NoError(t, err) - assert.Equal(t, "registerArn", arn) - assert.Equal(t, "us-west-2b", availabilityzone) - }) - } -} - -func TestReRegisterContainerInstance(t *testing.T) { - additionalAttributes := map[string]string{"my_custom_attribute": "Custom_Value1", - "my_other_custom_attribute": "Custom_Value2", - "attribute_name_with_no_value": "", - } - - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(mockCtrl) - client, mc, _ := NewMockClient(mockCtrl, mockEC2Metadata, additionalAttributes) - - fakeCapabilities := []string{"capability1", "capability2"} - expectedAttributes := map[string]string{ - "ecs.os-type": config.OSType, - "ecs.os-family": config.GetOSFamily(), - "ecs.availability-zone": "us-west-2b", - "ecs.outpost-arn": "test:arn:outpost", - } - for i := range fakeCapabilities { - expectedAttributes[fakeCapabilities[i]] = "" - } - capabilities := buildAttributeList(fakeCapabilities, nil) - - gomock.InOrder( - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentResource).Return("instanceIdentityDocument", nil), - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource).Return("signature", nil), - mc.EXPECT().RegisterContainerInstance(gomock.Any()).Do(func(req *ecs.RegisterContainerInstanceInput) { - assert.Equal(t, "arn:test", *req.ContainerInstanceArn, "Wrong container instance ARN") - assert.Equal(t, configuredCluster, *req.Cluster, "Wrong cluster") - assert.Equal(t, registrationToken, *req.ClientToken, "Wrong client token") - assert.Equal(t, iid, *req.InstanceIdentityDocument, "Wrong IID") - assert.Equal(t, iidSignature, *req.InstanceIdentityDocumentSignature, "Wrong IID sig") - assert.Equal(t, 4, len(req.TotalResources), "Wrong length of TotalResources") - resource, ok := findResource(req.TotalResources, "PORTS_UDP") - assert.True(t, ok, `Could not find resource "PORTS_UDP"`) - assert.Equal(t, "STRINGSET", *resource.Type, `Wrong type for resource "PORTS_UDP"`) - // "ecs.os-type", ecs.os-family, ecs.outpost-arn and the 2 that we specified as additionalAttributes - assert.Equal(t, 5, len(req.Attributes), "Wrong number of Attributes") - reqAttributes := func() map[string]string { - rv := make(map[string]string, len(req.Attributes)) - for i := range req.Attributes { - rv[aws.StringValue(req.Attributes[i].Name)] = aws.StringValue(req.Attributes[i].Value) - } - return rv - }() - for k, v := range reqAttributes { - assert.Contains(t, expectedAttributes, k) - assert.Equal(t, expectedAttributes[k], v) - } - assert.Equal(t, len(containerInstanceTags), len(req.Tags), "Wrong number of tags") - reqTags := extractTagsMapFromRegisterContainerInstanceInput(req) - for k, v := range reqTags { - assert.Contains(t, containerInstanceTagsMap, k) - assert.Equal(t, containerInstanceTagsMap[k], v) - } - }).Return(&ecs.RegisterContainerInstanceOutput{ - ContainerInstance: &ecs.ContainerInstance{ - ContainerInstanceArn: aws.String("registerArn"), - Attributes: buildAttributeList(fakeCapabilities, expectedAttributes), - }}, - nil), - ) - - arn, availabilityzone, err := client.RegisterContainerInstance("arn:test", capabilities, - containerInstanceTags, registrationToken, nil, "test:arn:outpost") - - assert.NoError(t, err) - assert.Equal(t, "registerArn", arn) - assert.Equal(t, "us-west-2b", availabilityzone, "availabilityZone is incorrect") -} - -// TestRegisterContainerInstanceWithNegativeResource tests the registeration should fail with negative resource -func TestRegisterContainerInstanceWithNegativeResource(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - _, mem := getCpuAndMemory() - mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(mockCtrl) - client := NewECSClient(credentials.AnonymousCredentials, - &config.Config{Cluster: configuredCluster, - AWSRegion: "us-east-1", - ReservedMemory: uint16(mem) + 1, - }, mockEC2Metadata) - mockSDK := mock_api.NewMockECSSDK(mockCtrl) - mockSubmitStateSDK := mock_api.NewMockECSSubmitStateSDK(mockCtrl) - client.(*APIECSClient).SetSDK(mockSDK) - client.(*APIECSClient).SetSubmitStateChangeSDK(mockSubmitStateSDK) - - gomock.InOrder( - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentResource).Return("instanceIdentityDocument", nil), - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource).Return("signature", nil), - ) - _, _, err := client.RegisterContainerInstance("", nil, nil, - "", nil, "") - assert.Error(t, err, "Register resource with negative value should cause registration fail") -} - -func TestRegisterContainerInstanceWithEmptyTags(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(mockCtrl) - client, mc, _ := NewMockClient(mockCtrl, mockEC2Metadata, nil) - - expectedAttributes := map[string]string{ - "ecs.os-type": config.OSType, - "ecs.os-family": config.GetOSFamily(), - "my_custom_attribute": "Custom_Value1", - "my_other_custom_attribute": "Custom_Value2", - } - - fakeCapabilities := []string{"capability1", "capability2"} - - gomock.InOrder( - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentResource).Return("instanceIdentityDocument", nil), - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource).Return("signature", nil), - mc.EXPECT().RegisterContainerInstance(gomock.Any()).Do(func(req *ecs.RegisterContainerInstanceInput) { - assert.Nil(t, req.Tags) - }).Return(&ecs.RegisterContainerInstanceOutput{ - ContainerInstance: &ecs.ContainerInstance{ - ContainerInstanceArn: aws.String("registerArn"), - Attributes: buildAttributeList(fakeCapabilities, expectedAttributes)}}, - nil), - ) - - _, _, err := client.RegisterContainerInstance("", nil, make([]*ecs.Tag, 0), - "", nil, "") - assert.NoError(t, err) -} - -func TestValidateRegisteredAttributes(t *testing.T) { - origAttributes := []*ecs.Attribute{ - {Name: aws.String("foo"), Value: aws.String("bar")}, - {Name: aws.String("baz"), Value: aws.String("quux")}, - {Name: aws.String("no_value"), Value: aws.String("")}, - } - actualAttributes := []*ecs.Attribute{ - {Name: aws.String("baz"), Value: aws.String("quux")}, - {Name: aws.String("foo"), Value: aws.String("bar")}, - {Name: aws.String("no_value"), Value: aws.String("")}, - {Name: aws.String("ecs.internal-attribute"), Value: aws.String("some text")}, - } - assert.NoError(t, validateRegisteredAttributes(origAttributes, actualAttributes)) - - origAttributes = append(origAttributes, &ecs.Attribute{Name: aws.String("abc"), Value: aws.String("xyz")}) - assert.Error(t, validateRegisteredAttributes(origAttributes, actualAttributes)) -} - -func findResource(resources []*ecs.Resource, name string) (*ecs.Resource, bool) { - for _, resource := range resources { - if name == *resource.Name { - return resource, true - } - } - return nil, false -} - -func TestRegisterBlankCluster(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(mockCtrl) - - // Test the special 'empty cluster' behavior of creating 'default' - client := NewECSClient(credentials.AnonymousCredentials, - &config.Config{ - Cluster: "", - AWSRegion: "us-east-1", - }, - mockEC2Metadata) - mc := mock_api.NewMockECSSDK(mockCtrl) - client.(*APIECSClient).SetSDK(mc) - - expectedAttributes := map[string]string{ - "ecs.os-type": config.OSType, - "ecs.os-family": config.GetOSFamily(), - } - defaultCluster := config.DefaultClusterName - gomock.InOrder( - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentResource).Return("instanceIdentityDocument", nil), - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource).Return("signature", nil), - mc.EXPECT().RegisterContainerInstance(gomock.Any()).Return(nil, awserr.New("ClientException", "Cluster not found.", errors.New("Cluster not found."))), - mc.EXPECT().CreateCluster(&ecs.CreateClusterInput{ClusterName: &defaultCluster}).Return(&ecs.CreateClusterOutput{Cluster: &ecs.Cluster{ClusterName: &defaultCluster}}, nil), - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentResource).Return("instanceIdentityDocument", nil), - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource).Return("signature", nil), - mc.EXPECT().RegisterContainerInstance(gomock.Any()).Do(func(req *ecs.RegisterContainerInstanceInput) { - if *req.Cluster != config.DefaultClusterName { - t.Errorf("Wrong cluster: %v", *req.Cluster) - } - if *req.InstanceIdentityDocument != iid { - t.Errorf("Wrong IID: %v", *req.InstanceIdentityDocument) - } - if *req.InstanceIdentityDocumentSignature != iidSignature { - t.Errorf("Wrong IID sig: %v", *req.InstanceIdentityDocumentSignature) - } - }).Return(&ecs.RegisterContainerInstanceOutput{ - ContainerInstance: &ecs.ContainerInstance{ - ContainerInstanceArn: aws.String("registerArn"), - Attributes: buildAttributeList(nil, expectedAttributes)}}, - nil), - ) - - arn, availabilityzone, err := client.RegisterContainerInstance("", nil, nil, - "", nil, "") - if err != nil { - t.Errorf("Should not be an error: %v", err) - } - if arn != "registerArn" { - t.Errorf("Wrong arn: %v", arn) - } - if availabilityzone != "" { - t.Errorf("wrong availability zone: %v", availabilityzone) - } -} - -func TestRegisterBlankClusterNotCreatingClusterWhenErrorNotClusterNotFound(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(mockCtrl) - - // Test the special 'empty cluster' behavior of creating 'default' - client := NewECSClient(credentials.AnonymousCredentials, - &config.Config{ - Cluster: "", - AWSRegion: "us-east-1", - }, - mockEC2Metadata) - mc := mock_api.NewMockECSSDK(mockCtrl) - client.(*APIECSClient).SetSDK(mc) - - expectedAttributes := map[string]string{ - "ecs.os-type": config.OSType, - "ecs.os-family": config.GetOSFamily(), - } - - gomock.InOrder( - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentResource).Return("instanceIdentityDocument", nil), - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource).Return("signature", nil), - mc.EXPECT().RegisterContainerInstance(gomock.Any()).Return(nil, awserr.New("ClientException", "Invalid request.", errors.New("Invalid request."))), - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentResource).Return("instanceIdentityDocument", nil), - mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource).Return("signature", nil), - mc.EXPECT().RegisterContainerInstance(gomock.Any()).Do(func(req *ecs.RegisterContainerInstanceInput) { - if *req.Cluster != config.DefaultClusterName { - t.Errorf("Wrong cluster: %v", *req.Cluster) - } - if *req.InstanceIdentityDocument != iid { - t.Errorf("Wrong IID: %v", *req.InstanceIdentityDocument) - } - if *req.InstanceIdentityDocumentSignature != iidSignature { - t.Errorf("Wrong IID sig: %v", *req.InstanceIdentityDocumentSignature) - } - }).Return(&ecs.RegisterContainerInstanceOutput{ - ContainerInstance: &ecs.ContainerInstance{ - ContainerInstanceArn: aws.String("registerArn"), - Attributes: buildAttributeList(nil, expectedAttributes)}}, - nil), - ) - - arn, _, err := client.RegisterContainerInstance("", nil, nil, "", - nil, "") - assert.NoError(t, err, "Should not return error") - assert.Equal(t, "registerArn", arn, "Wrong arn") -} - -func TestDiscoverTelemetryEndpoint(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - client, mc, _ := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - expectedEndpoint := "http://127.0.0.1" - mc.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(&ecs.DiscoverPollEndpointOutput{TelemetryEndpoint: &expectedEndpoint}, nil) - endpoint, err := client.DiscoverTelemetryEndpoint("containerInstance") - if err != nil { - t.Error("Error getting telemetry endpoint: ", err) - } - if expectedEndpoint != endpoint { - t.Errorf("Expected telemetry endpoint(%s) != endpoint(%s)", expectedEndpoint, endpoint) - } -} - -func TestDiscoverTelemetryEndpointError(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - client, mc, _ := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - mc.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(nil, fmt.Errorf("Error getting endpoint")) - _, err := client.DiscoverTelemetryEndpoint("containerInstance") - if err == nil { - t.Error("Expected error getting telemetry endpoint, didn't get any") - } -} - -func TestDiscoverNilTelemetryEndpoint(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - client, mc, _ := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - pollEndpoint := "http://127.0.0.1" - mc.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(&ecs.DiscoverPollEndpointOutput{Endpoint: &pollEndpoint}, nil) - _, err := client.DiscoverTelemetryEndpoint("containerInstance") - if err == nil { - t.Error("Expected error getting telemetry endpoint with old response") - } -} - -func TestDiscoverServiceConnectEndpoint(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - client, mc, _ := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - expectedEndpoint := "http://127.0.0.1" - mc.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(&ecs.DiscoverPollEndpointOutput{ServiceConnectEndpoint: &expectedEndpoint}, nil) - endpoint, err := client.DiscoverServiceConnectEndpoint("containerInstance") - if err != nil { - t.Error("Error getting service connect endpoint: ", err) - } - if expectedEndpoint != endpoint { - t.Errorf("Expected telemetry endpoint(%s) != endpoint(%s)", expectedEndpoint, endpoint) - } -} - -func TestDiscoverServiceConnectEndpointError(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - client, mc, _ := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - mc.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(nil, fmt.Errorf("Error getting endpoint")) - _, err := client.DiscoverServiceConnectEndpoint("containerInstance") - if err == nil { - t.Error("Expected error getting service connect endpoint, didn't get any") - } -} - -func TestDiscoverNilServiceConnectEndpoint(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - client, mc, _ := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - pollEndpoint := "http://127.0.0.1" - mc.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(&ecs.DiscoverPollEndpointOutput{Endpoint: &pollEndpoint}, nil) - _, err := client.DiscoverServiceConnectEndpoint("containerInstance") - if err == nil { - t.Error("Expected error getting service connect endpoint with old response") - } -} - -func TestUpdateContainerInstancesState(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - client, mc, _ := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - - instanceARN := "myInstanceARN" - status := "DRAINING" - mc.EXPECT().UpdateContainerInstancesState(&ecs.UpdateContainerInstancesStateInput{ - ContainerInstances: []*string{aws.String(instanceARN)}, - Status: aws.String(status), - Cluster: aws.String(configuredCluster), - }).Return(&ecs.UpdateContainerInstancesStateOutput{}, nil) - - err := client.UpdateContainerInstancesState(instanceARN, status) - assert.NoError(t, err, fmt.Sprintf("Unexpected error calling UpdateContainerInstancesState: %s", err)) -} - -func TestUpdateContainerInstancesStateError(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - client, mc, _ := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - - instanceARN := "myInstanceARN" - status := "DRAINING" - mc.EXPECT().UpdateContainerInstancesState(&ecs.UpdateContainerInstancesStateInput{ - ContainerInstances: []*string{aws.String(instanceARN)}, - Status: aws.String(status), - Cluster: aws.String(configuredCluster), - }).Return(nil, fmt.Errorf("ERROR")) - - err := client.UpdateContainerInstancesState(instanceARN, status) - assert.Error(t, err, "Expected an error calling UpdateContainerInstancesState but got nil") -} - -func TestGetResourceTags(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - client, mc, _ := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - - instanceARN := "myInstanceARN" - mc.EXPECT().ListTagsForResource(&ecs.ListTagsForResourceInput{ - ResourceArn: aws.String(instanceARN), - }).Return(&ecs.ListTagsForResourceOutput{ - Tags: containerInstanceTags, - }, nil) - - _, err := client.GetResourceTags(instanceARN) - assert.NoError(t, err, fmt.Sprintf("Unexpected error calling GetResourceTags: %s", err)) -} - -func TestGetResourceTagsError(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - client, mc, _ := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - - instanceARN := "myInstanceARN" - mc.EXPECT().ListTagsForResource(&ecs.ListTagsForResourceInput{ - ResourceArn: aws.String(instanceARN), - }).Return(nil, fmt.Errorf("ERROR")) - - _, err := client.GetResourceTags(instanceARN) - assert.Error(t, err, "Expected an error calling GetResourceTags but got nil") -} - -func TestDiscoverPollEndpointCacheHit(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockSDK := mock_api.NewMockECSSDK(mockCtrl) - pollEndpointCache := mock_async.NewMockTTLCache(mockCtrl) - client := &APIECSClient{ - credentialProvider: credentials.AnonymousCredentials, - config: &config.Config{ - Cluster: configuredCluster, - AWSRegion: "us-east-1", - }, - standardClient: mockSDK, - ec2metadata: ec2.NewBlackholeEC2MetadataClient(), - pollEndpointCache: pollEndpointCache, - } - - pollEndpoint := "http://127.0.0.1" - pollEndpointCache.EXPECT().Get("containerInstance").Return( - &ecs.DiscoverPollEndpointOutput{ - Endpoint: aws.String(pollEndpoint), - }, false, true) - output, err := client.discoverPollEndpoint("containerInstance") - if err != nil { - t.Fatalf("Error in discoverPollEndpoint: %v", err) - } - if aws.StringValue(output.Endpoint) != pollEndpoint { - t.Errorf("Mismatch in poll endpoint: %s != %s", aws.StringValue(output.Endpoint), pollEndpoint) - } -} - -func TestDiscoverPollEndpointCacheMiss(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockSDK := mock_api.NewMockECSSDK(mockCtrl) - pollEndpointCache := mock_async.NewMockTTLCache(mockCtrl) - client := &APIECSClient{ - credentialProvider: credentials.AnonymousCredentials, - config: &config.Config{ - Cluster: configuredCluster, - AWSRegion: "us-east-1", - }, - standardClient: mockSDK, - ec2metadata: ec2.NewBlackholeEC2MetadataClient(), - pollEndpointCache: pollEndpointCache, - } - pollEndpoint := "http://127.0.0.1" - pollEndpointOutput := &ecs.DiscoverPollEndpointOutput{ - Endpoint: &pollEndpoint, - } - - gomock.InOrder( - pollEndpointCache.EXPECT().Get("containerInstance").Return(nil, false, false), - mockSDK.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(pollEndpointOutput, nil), - pollEndpointCache.EXPECT().Set("containerInstance", pollEndpointOutput), - ) - - output, err := client.discoverPollEndpoint("containerInstance") - if err != nil { - t.Fatalf("Error in discoverPollEndpoint: %v", err) - } - if aws.StringValue(output.Endpoint) != pollEndpoint { - t.Errorf("Mismatch in poll endpoint: %s != %s", aws.StringValue(output.Endpoint), pollEndpoint) - } -} - -func TestDiscoverPollEndpointExpiredButDPEFailed(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockSDK := mock_api.NewMockECSSDK(mockCtrl) - pollEndpointCache := mock_async.NewMockTTLCache(mockCtrl) - client := &APIECSClient{ - credentialProvider: credentials.AnonymousCredentials, - config: &config.Config{ - Cluster: configuredCluster, - AWSRegion: "us-east-1", - }, - standardClient: mockSDK, - ec2metadata: ec2.NewBlackholeEC2MetadataClient(), - pollEndpointCache: pollEndpointCache, - } - pollEndpoint := "http://127.0.0.1" - pollEndpointOutput := &ecs.DiscoverPollEndpointOutput{ - Endpoint: &pollEndpoint, - } - - gomock.InOrder( - pollEndpointCache.EXPECT().Get("containerInstance").Return(pollEndpointOutput, true, false), - mockSDK.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(nil, fmt.Errorf("error!")), - ) - - output, err := client.discoverPollEndpoint("containerInstance") - if err != nil { - t.Fatalf("Error in discoverPollEndpoint: %v", err) - } - if aws.StringValue(output.Endpoint) != pollEndpoint { - t.Errorf("Mismatch in poll endpoint: %s != %s", aws.StringValue(output.Endpoint), pollEndpoint) - } -} - -func TestDiscoverTelemetryEndpointAfterPollEndpointCacheHit(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - mockSDK := mock_api.NewMockECSSDK(mockCtrl) - pollEndpointCache := async.NewTTLCache(&async.TTL{Duration: 10 * time.Minute}) - client := &APIECSClient{ - credentialProvider: credentials.AnonymousCredentials, - config: &config.Config{ - Cluster: configuredCluster, - AWSRegion: "us-east-1", - }, - standardClient: mockSDK, - ec2metadata: ec2.NewBlackholeEC2MetadataClient(), - pollEndpointCache: pollEndpointCache, - } - - pollEndpoint := "http://127.0.0.1" - mockSDK.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return( - &ecs.DiscoverPollEndpointOutput{ - Endpoint: &pollEndpoint, - TelemetryEndpoint: &pollEndpoint, - }, nil) - endpoint, err := client.DiscoverPollEndpoint("containerInstance") - if err != nil { - t.Fatalf("Error in discoverPollEndpoint: %v", err) - } - if endpoint != pollEndpoint { - t.Errorf("Mismatch in poll endpoint: %s", endpoint) - } - telemetryEndpoint, err := client.DiscoverTelemetryEndpoint("containerInstance") - if err != nil { - t.Fatalf("Error in discoverTelemetryEndpoint: %v", err) - } - if telemetryEndpoint != pollEndpoint { - t.Errorf("Mismatch in poll endpoint: %s", endpoint) - } -} - -// TestSubmitTaskStateChangeWithENIAttachments tests the SubmitTaskStateChange API -// also send the Attachment Status -func TestSubmitTaskStateChangeWithENIAttachments(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - client, _, mockSubmitStateClient := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - mockSubmitStateClient.EXPECT().SubmitTaskStateChange(&taskSubmitInputMatcher{ - ecs.SubmitTaskStateChangeInput{ - Cluster: aws.String(configuredCluster), - Task: aws.String("task_arn"), - Attachments: []*ecs.AttachmentStateChange{ - { - AttachmentArn: aws.String("eni_arn"), - Status: aws.String("ATTACHED"), - }, - }, - }, - }) - - err := client.SubmitTaskStateChange(api.TaskStateChange{ - TaskARN: "task_arn", - Attachment: &ni.ENIAttachment{ - AttachmentInfo: attachment.AttachmentInfo{ - AttachmentARN: "eni_arn", - Status: attachment.AttachmentAttached, - }, - }, - }) - assert.NoError(t, err, "Unable to submit task state change with attachments") -} - -func TestSubmitTaskStateChangeWithoutAttachments(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - client, _, mockSubmitStateClient := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - mockSubmitStateClient.EXPECT().SubmitTaskStateChange(&taskSubmitInputMatcher{ - ecs.SubmitTaskStateChangeInput{ - Cluster: aws.String(configuredCluster), - Task: aws.String("task_arn"), - Reason: aws.String(""), - Status: aws.String("RUNNING"), - }, - }) - - err := client.SubmitTaskStateChange(api.TaskStateChange{ - TaskARN: "task_arn", - Status: apitaskstatus.TaskRunning, - }) - assert.NoError(t, err, "Unable to submit task state change with no attachments") -} - -func TestSubmitTaskStateChangeWithManagedAgents(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - client, _, mockSubmitStateClient := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - mockSubmitStateClient.EXPECT().SubmitTaskStateChange(&taskSubmitInputMatcher{ - ecs.SubmitTaskStateChangeInput{ - Cluster: aws.String(configuredCluster), - Task: aws.String("task_arn"), - Reason: aws.String(""), - Status: aws.String("RUNNING"), - ManagedAgents: testManagedAgents, - }, - }) - - err := client.SubmitTaskStateChange(api.TaskStateChange{ - TaskARN: "task_arn", - Status: apitaskstatus.TaskRunning, - }) - assert.NoError(t, err, "Unable to submit task state change with managed agents") -} - -// TestSubmitContainerStateChangeWhileTaskInPending tests the container state change was submitted -// when the task is still in pending state -func TestSubmitContainerStateChangeWhileTaskInPending(t *testing.T) { - mockCtrl := gomock.NewController(t) - defer mockCtrl.Finish() - - testCases := []struct { - taskStatus apitaskstatus.TaskStatus - }{ - { - apitaskstatus.TaskStatusNone, - }, - { - apitaskstatus.TaskPulled, - }, - { - apitaskstatus.TaskCreated, - }, - } - - taskStateChangePending := api.TaskStateChange{ - Status: apitaskstatus.TaskCreated, - TaskARN: "arn", - Containers: []api.ContainerStateChange{ - { - TaskArn: "arn", - ContainerName: "container", - RuntimeID: "runtimeid", - Container: &apicontainer.Container{ - NetworkModeUnsafe: testNetworkName, - }, - Status: apicontainerstatus.ContainerRunning, - }, - }, - } - - for _, tc := range testCases { - t.Run(fmt.Sprintf("TaskStatus: %s", tc.taskStatus.String()), func(t *testing.T) { - taskStateChangePending.Status = tc.taskStatus - client, _, mockSubmitStateClient := NewMockClient(mockCtrl, ec2.NewBlackholeEC2MetadataClient(), nil) - mockSubmitStateClient.EXPECT().SubmitTaskStateChange(&taskSubmitInputMatcher{ - ecs.SubmitTaskStateChangeInput{ - Cluster: strptr(configuredCluster), - Task: strptr("arn"), - Status: strptr("PENDING"), - Reason: strptr(""), - Containers: []*ecs.ContainerStateChange{ - { - ContainerName: strptr("container"), - RuntimeId: strptr("runtimeid"), - Status: strptr("RUNNING"), - NetworkBindings: []*ecs.NetworkBinding{}, - }, - }, - }, - }) - err := client.SubmitTaskStateChange(taskStateChangePending) - assert.NoError(t, err) - }) - } -} - -func extractTagsMapFromRegisterContainerInstanceInput(req *ecs.RegisterContainerInstanceInput) map[string]string { - tagsMap := make(map[string]string, len(req.Tags)) - for i := range req.Tags { - tagsMap[aws.StringValue(req.Tags[i].Key)] = aws.StringValue(req.Tags[i].Value) - } - return tagsMap -} - -func getTestContainerStateChange() api.ContainerStateChange { - testContainer := &apicontainer.Container{ - Name: "cont", - NetworkModeUnsafe: testNetworkName, - Ports: []apicontainer.PortBinding{ - { - ContainerPort: 10, - HostPort: 10, - Protocol: apicontainer.TransportProtocolTCP, - }, - { - ContainerPort: 12, - HostPort: 12, - Protocol: apicontainer.TransportProtocolUDP, - }, - { - ContainerPort: 15, - Protocol: apicontainer.TransportProtocolTCP, - }, - { - ContainerPortRange: "21-22", - Protocol: apicontainer.TransportProtocolUDP, - }, - { - ContainerPortRange: "96-97", - Protocol: apicontainer.TransportProtocolTCP, - }, - }, - ContainerHasPortRange: true, - ContainerPortSet: map[int]struct{}{ - 10: {}, - 12: {}, - 15: {}, - }, - ContainerPortRangeMap: map[string]string{ - "21-22": "60001-60002", - "96-97": "47001-47002", - }, - } - - testContainerStateChange := api.ContainerStateChange{ - TaskArn: "arn", - ContainerName: "cont", - Status: apicontainerstatus.ContainerRunning, - Container: testContainer, - PortBindings: []apicontainer.PortBinding{ - { - ContainerPort: 10, - HostPort: 10, - BindIP: "0.0.0.0", - Protocol: apicontainer.TransportProtocolTCP, - }, - { - ContainerPort: 12, - HostPort: 12, - BindIP: "1.2.3.4", - Protocol: apicontainer.TransportProtocolUDP, - }, - { - ContainerPort: 15, - HostPort: 20, - BindIP: "5.6.7.8", - Protocol: apicontainer.TransportProtocolTCP, - }, - { - ContainerPort: 21, - HostPort: 60001, - BindIP: "::", - Protocol: apicontainer.TransportProtocolUDP, - }, - { - ContainerPort: 22, - HostPort: 60002, - BindIP: "::", - Protocol: apicontainer.TransportProtocolUDP, - }, - { - ContainerPort: 96, - HostPort: 47001, - BindIP: "0.0.0.0", - Protocol: apicontainer.TransportProtocolTCP, - }, - { - ContainerPort: 97, - HostPort: 47002, - BindIP: "0.0.0.0", - Protocol: apicontainer.TransportProtocolTCP, - }, - }, - } - - return testContainerStateChange -} - -func TestGetNetworkBindings(t *testing.T) { - testContainerStateChange := getTestContainerStateChange() - expectedNetworkBindings := []*ecs.NetworkBinding{ - { - BindIP: strptr("0.0.0.0"), - ContainerPort: int64ptr(intptr(10)), - HostPort: int64ptr(intptr(10)), - Protocol: strptr("tcp"), - }, - { - BindIP: strptr("1.2.3.4"), - ContainerPort: int64ptr(intptr(12)), - HostPort: int64ptr(intptr(12)), - Protocol: strptr("udp"), - }, - { - BindIP: strptr("5.6.7.8"), - ContainerPort: int64ptr(intptr(15)), - HostPort: int64ptr(intptr(20)), - Protocol: strptr("tcp"), - }, - { - BindIP: strptr("::"), - ContainerPortRange: strptr("21-22"), - HostPortRange: strptr("60001-60002"), - Protocol: strptr("udp"), - }, - { - BindIP: strptr("0.0.0.0"), - ContainerPortRange: strptr("96-97"), - HostPortRange: strptr("47001-47002"), - Protocol: strptr("tcp"), - }, - } - - networkBindings := getNetworkBindings(testContainerStateChange, false) - assert.ElementsMatch(t, expectedNetworkBindings, networkBindings) -} diff --git a/agent/api/ecsclient/retry_handler_test.go b/agent/api/ecsclient/retry_handler_test.go deleted file mode 100644 index 127f52b86d0..00000000000 --- a/agent/api/ecsclient/retry_handler_test.go +++ /dev/null @@ -1,58 +0,0 @@ -//go:build unit -// +build unit - -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package ecsclient - -import ( - "errors" - "net/http" - "testing" - "time" - - "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/defaults" -) - -func TestOneDayRetrier(t *testing.T) { - stateChangeClient := newSubmitStateChangeClient(defaults.Config()) - - request, _ := stateChangeClient.SubmitContainerStateChangeRequest(&ecs.SubmitContainerStateChangeInput{}) - - retrier := stateChangeClient.Retryer - - var totalDelay time.Duration - for retries := 0; retries < retrier.MaxRetries(); retries++ { - request.Error = errors.New("") - request.Retryable = aws.Bool(true) - request.HTTPResponse = &http.Response{StatusCode: 500} - if request.WillRetry() && request.IsErrorRetryable() { - totalDelay += retrier.RetryRules(request) - request.RetryCount++ - } - } - - request.Error = errors.New("") - request.Retryable = aws.Bool(true) - request.HTTPResponse = &http.Response{StatusCode: 500} - if request.WillRetry() { - t.Errorf("Expected request to not be retried after %v retries", retrier.MaxRetries()) - } - - if totalDelay > 25*time.Hour || totalDelay < 23*time.Hour { - t.Errorf("Expected accumulated retry delay to be roughly 24 hours; was %v", totalDelay) - } -} diff --git a/agent/api/ecsclient/utils_amd64_test.go b/agent/api/ecsclient/utils_amd64_test.go deleted file mode 100644 index 5bb95300c54..00000000000 --- a/agent/api/ecsclient/utils_amd64_test.go +++ /dev/null @@ -1,27 +0,0 @@ -//go:build amd64 && unit -// +build amd64,unit - -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package ecsclient - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestGetCPUArch(t *testing.T) { - assert.Equal(t, "x86_64", getCPUArch()) -} diff --git a/agent/api/ecsclient/utils_arm64_test.go b/agent/api/ecsclient/utils_arm64_test.go deleted file mode 100644 index 1a46f13425b..00000000000 --- a/agent/api/ecsclient/utils_arm64_test.go +++ /dev/null @@ -1,27 +0,0 @@ -//go:build arm64 && unit -// +build arm64,unit - -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package ecsclient - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestGetCPUArch(t *testing.T) { - assert.Equal(t, "arm64", getCPUArch()) -} diff --git a/agent/api/generate_mocks.go b/agent/api/generate_mocks.go deleted file mode 100644 index c19794a9d82..00000000000 --- a/agent/api/generate_mocks.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package api - -//go:generate mockgen -destination=mocks/api_mocks.go -copyright_file=../../scripts/copyright_file github.com/aws/amazon-ecs-agent/agent/api ECSSDK,ECSSubmitStateSDK,ECSClient diff --git a/agent/api/interface.go b/agent/api/interface.go deleted file mode 100644 index 1cb6546200a..00000000000 --- a/agent/api/interface.go +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package api - -import ( - "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" -) - -// ECSClient is an interface over the ECSSDK interface which abstracts away some -// details around constructing the request and reading the response down to the -// parts the agent cares about. -// For example, the ever-present 'Cluster' member is abstracted out so that it -// may be configured once and used throughout transparently. -type ECSClient interface { - // RegisterContainerInstance calculates the appropriate resources, creates - // the default cluster if necessary, and returns the registered - // ContainerInstanceARN if successful. Supplying a non-empty container - // instance ARN allows a container instance to update its registered - // resources. - RegisterContainerInstance(existingContainerInstanceArn string, - attributes []*ecs.Attribute, tags []*ecs.Tag, registrationToken string, platformDevices []*ecs.PlatformDevice, - outpostARN string) (string, string, error) - // SubmitTaskStateChange sends a state change and returns an error - // indicating if it was submitted - SubmitTaskStateChange(change TaskStateChange) error - // SubmitContainerStateChange sends a state change and returns an error - // indicating if it was submitted - SubmitContainerStateChange(change ContainerStateChange) error - // SubmitAttachmentStateChange sends an attachment state change and returns an error - // indicating if it was submitted - SubmitAttachmentStateChange(change AttachmentStateChange) error - // DiscoverPollEndpoint takes a ContainerInstanceARN and returns the - // endpoint at which this Agent should contact ACS - DiscoverPollEndpoint(containerInstanceArn string) (string, error) - // DiscoverTelemetryEndpoint takes a ContainerInstanceARN and returns the - // endpoint at which this Agent should contact Telemetry Service - DiscoverTelemetryEndpoint(containerInstanceArn string) (string, error) - // DiscoverServiceConnectEndpoint takes a ContainerInstanceARN and returns the - // endpoint at which this Agent should contact ServiceConnect - DiscoverServiceConnectEndpoint(containerInstanceArn string) (string, error) - // GetResourceTags retrieves the Tags associated with a certain resource - GetResourceTags(resourceArn string) ([]*ecs.Tag, error) - // UpdateContainerInstancesState updates the given container Instance ID with - // the given status. Only valid statuses are ACTIVE and DRAINING. - UpdateContainerInstancesState(instanceARN, status string) error - // GetHostResources retrieves a map that map the resource name to the corresponding resource - GetHostResources() (map[string]*ecs.Resource, error) -} - -// ECSSDK is an interface that specifies the subset of the AWS Go SDK's ECS -// client that the Agent uses. This interface is meant to allow injecting a -// mock for testing. -type ECSSDK interface { - CreateCluster(*ecs.CreateClusterInput) (*ecs.CreateClusterOutput, error) - RegisterContainerInstance(*ecs.RegisterContainerInstanceInput) (*ecs.RegisterContainerInstanceOutput, error) - DiscoverPollEndpoint(*ecs.DiscoverPollEndpointInput) (*ecs.DiscoverPollEndpointOutput, error) - ListTagsForResource(*ecs.ListTagsForResourceInput) (*ecs.ListTagsForResourceOutput, error) - UpdateContainerInstancesState(input *ecs.UpdateContainerInstancesStateInput) (*ecs.UpdateContainerInstancesStateOutput, error) -} - -// ECSSubmitStateSDK is an interface with customized ecs client that -// implements the SubmitTaskStateChange and SubmitContainerStateChange -type ECSSubmitStateSDK interface { - SubmitContainerStateChange(*ecs.SubmitContainerStateChangeInput) (*ecs.SubmitContainerStateChangeOutput, error) - SubmitTaskStateChange(*ecs.SubmitTaskStateChangeInput) (*ecs.SubmitTaskStateChangeOutput, error) - SubmitAttachmentStateChanges(*ecs.SubmitAttachmentStateChangesInput) (*ecs.SubmitAttachmentStateChangesOutput, error) -} diff --git a/agent/api/metadata_getter.go b/agent/api/metadata_getter.go index 2a7ddf707fd..594b0e3435b 100644 --- a/agent/api/metadata_getter.go +++ b/agent/api/metadata_getter.go @@ -25,7 +25,7 @@ type containerMetadataGetter struct { container *apicontainer.Container } -func NewContainerMetadataGetter(container *apicontainer.Container) *containerMetadataGetter { +func newContainerMetadataGetter(container *apicontainer.Container) *containerMetadataGetter { return &containerMetadataGetter{ container: container, } @@ -58,7 +58,7 @@ type taskMetadataGetter struct { task *apitask.Task } -func NewTaskMetadataGetter(task *apitask.Task) *taskMetadataGetter { +func newTaskMetadataGetter(task *apitask.Task) *taskMetadataGetter { return &taskMetadataGetter{ task: task, } diff --git a/agent/api/metadata_getter_test.go b/agent/api/metadata_getter_test.go index f89cccfc568..61b25cae170 100644 --- a/agent/api/metadata_getter_test.go +++ b/agent/api/metadata_getter_test.go @@ -42,7 +42,7 @@ func TestTaskStateChangeMetadataGetter(t *testing.T) { ExecutionStoppedAtUnsafe: t3, } - metadataGetter := NewTaskMetadataGetter(task) + metadataGetter := newTaskMetadataGetter(task) change := &ecs.TaskStateChange{ MetadataGetter: metadataGetter, @@ -67,7 +67,7 @@ func TestContainerStateChangeMetadataGetter(t *testing.T) { SentStatusUnsafe: apicontainerstatus.ContainerRunning, } - metadataGetter := NewContainerMetadataGetter(container) + metadataGetter := newContainerMetadataGetter(container) change := &ecs.ContainerStateChange{ MetadataGetter: metadataGetter, diff --git a/agent/api/mocks/api_mocks.go b/agent/api/mocks/api_mocks.go deleted file mode 100644 index 57990a766d7..00000000000 --- a/agent/api/mocks/api_mocks.go +++ /dev/null @@ -1,363 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. -// - -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/aws/amazon-ecs-agent/agent/api (interfaces: ECSSDK,ECSSubmitStateSDK,ECSClient) - -// Package mock_api is a generated GoMock package. -package mock_api - -import ( - reflect "reflect" - - api "github.com/aws/amazon-ecs-agent/agent/api" - ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" - gomock "github.com/golang/mock/gomock" -) - -// MockECSSDK is a mock of ECSSDK interface. -type MockECSSDK struct { - ctrl *gomock.Controller - recorder *MockECSSDKMockRecorder -} - -// MockECSSDKMockRecorder is the mock recorder for MockECSSDK. -type MockECSSDKMockRecorder struct { - mock *MockECSSDK -} - -// NewMockECSSDK creates a new mock instance. -func NewMockECSSDK(ctrl *gomock.Controller) *MockECSSDK { - mock := &MockECSSDK{ctrl: ctrl} - mock.recorder = &MockECSSDKMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockECSSDK) EXPECT() *MockECSSDKMockRecorder { - return m.recorder -} - -// CreateCluster mocks base method. -func (m *MockECSSDK) CreateCluster(arg0 *ecs.CreateClusterInput) (*ecs.CreateClusterOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateCluster", arg0) - ret0, _ := ret[0].(*ecs.CreateClusterOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateCluster indicates an expected call of CreateCluster. -func (mr *MockECSSDKMockRecorder) CreateCluster(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateCluster", reflect.TypeOf((*MockECSSDK)(nil).CreateCluster), arg0) -} - -// DiscoverPollEndpoint mocks base method. -func (m *MockECSSDK) DiscoverPollEndpoint(arg0 *ecs.DiscoverPollEndpointInput) (*ecs.DiscoverPollEndpointOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DiscoverPollEndpoint", arg0) - ret0, _ := ret[0].(*ecs.DiscoverPollEndpointOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DiscoverPollEndpoint indicates an expected call of DiscoverPollEndpoint. -func (mr *MockECSSDKMockRecorder) DiscoverPollEndpoint(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscoverPollEndpoint", reflect.TypeOf((*MockECSSDK)(nil).DiscoverPollEndpoint), arg0) -} - -// ListTagsForResource mocks base method. -func (m *MockECSSDK) ListTagsForResource(arg0 *ecs.ListTagsForResourceInput) (*ecs.ListTagsForResourceOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "ListTagsForResource", arg0) - ret0, _ := ret[0].(*ecs.ListTagsForResourceOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// ListTagsForResource indicates an expected call of ListTagsForResource. -func (mr *MockECSSDKMockRecorder) ListTagsForResource(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTagsForResource", reflect.TypeOf((*MockECSSDK)(nil).ListTagsForResource), arg0) -} - -// RegisterContainerInstance mocks base method. -func (m *MockECSSDK) RegisterContainerInstance(arg0 *ecs.RegisterContainerInstanceInput) (*ecs.RegisterContainerInstanceOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RegisterContainerInstance", arg0) - ret0, _ := ret[0].(*ecs.RegisterContainerInstanceOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// RegisterContainerInstance indicates an expected call of RegisterContainerInstance. -func (mr *MockECSSDKMockRecorder) RegisterContainerInstance(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterContainerInstance", reflect.TypeOf((*MockECSSDK)(nil).RegisterContainerInstance), arg0) -} - -// UpdateContainerInstancesState mocks base method. -func (m *MockECSSDK) UpdateContainerInstancesState(arg0 *ecs.UpdateContainerInstancesStateInput) (*ecs.UpdateContainerInstancesStateOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateContainerInstancesState", arg0) - ret0, _ := ret[0].(*ecs.UpdateContainerInstancesStateOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpdateContainerInstancesState indicates an expected call of UpdateContainerInstancesState. -func (mr *MockECSSDKMockRecorder) UpdateContainerInstancesState(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateContainerInstancesState", reflect.TypeOf((*MockECSSDK)(nil).UpdateContainerInstancesState), arg0) -} - -// MockECSSubmitStateSDK is a mock of ECSSubmitStateSDK interface. -type MockECSSubmitStateSDK struct { - ctrl *gomock.Controller - recorder *MockECSSubmitStateSDKMockRecorder -} - -// MockECSSubmitStateSDKMockRecorder is the mock recorder for MockECSSubmitStateSDK. -type MockECSSubmitStateSDKMockRecorder struct { - mock *MockECSSubmitStateSDK -} - -// NewMockECSSubmitStateSDK creates a new mock instance. -func NewMockECSSubmitStateSDK(ctrl *gomock.Controller) *MockECSSubmitStateSDK { - mock := &MockECSSubmitStateSDK{ctrl: ctrl} - mock.recorder = &MockECSSubmitStateSDKMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockECSSubmitStateSDK) EXPECT() *MockECSSubmitStateSDKMockRecorder { - return m.recorder -} - -// SubmitAttachmentStateChanges mocks base method. -func (m *MockECSSubmitStateSDK) SubmitAttachmentStateChanges(arg0 *ecs.SubmitAttachmentStateChangesInput) (*ecs.SubmitAttachmentStateChangesOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SubmitAttachmentStateChanges", arg0) - ret0, _ := ret[0].(*ecs.SubmitAttachmentStateChangesOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// SubmitAttachmentStateChanges indicates an expected call of SubmitAttachmentStateChanges. -func (mr *MockECSSubmitStateSDKMockRecorder) SubmitAttachmentStateChanges(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubmitAttachmentStateChanges", reflect.TypeOf((*MockECSSubmitStateSDK)(nil).SubmitAttachmentStateChanges), arg0) -} - -// SubmitContainerStateChange mocks base method. -func (m *MockECSSubmitStateSDK) SubmitContainerStateChange(arg0 *ecs.SubmitContainerStateChangeInput) (*ecs.SubmitContainerStateChangeOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SubmitContainerStateChange", arg0) - ret0, _ := ret[0].(*ecs.SubmitContainerStateChangeOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// SubmitContainerStateChange indicates an expected call of SubmitContainerStateChange. -func (mr *MockECSSubmitStateSDKMockRecorder) SubmitContainerStateChange(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubmitContainerStateChange", reflect.TypeOf((*MockECSSubmitStateSDK)(nil).SubmitContainerStateChange), arg0) -} - -// SubmitTaskStateChange mocks base method. -func (m *MockECSSubmitStateSDK) SubmitTaskStateChange(arg0 *ecs.SubmitTaskStateChangeInput) (*ecs.SubmitTaskStateChangeOutput, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SubmitTaskStateChange", arg0) - ret0, _ := ret[0].(*ecs.SubmitTaskStateChangeOutput) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// SubmitTaskStateChange indicates an expected call of SubmitTaskStateChange. -func (mr *MockECSSubmitStateSDKMockRecorder) SubmitTaskStateChange(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubmitTaskStateChange", reflect.TypeOf((*MockECSSubmitStateSDK)(nil).SubmitTaskStateChange), arg0) -} - -// MockECSClient is a mock of ECSClient interface. -type MockECSClient struct { - ctrl *gomock.Controller - recorder *MockECSClientMockRecorder -} - -// MockECSClientMockRecorder is the mock recorder for MockECSClient. -type MockECSClientMockRecorder struct { - mock *MockECSClient -} - -// NewMockECSClient creates a new mock instance. -func NewMockECSClient(ctrl *gomock.Controller) *MockECSClient { - mock := &MockECSClient{ctrl: ctrl} - mock.recorder = &MockECSClientMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockECSClient) EXPECT() *MockECSClientMockRecorder { - return m.recorder -} - -// DiscoverPollEndpoint mocks base method. -func (m *MockECSClient) DiscoverPollEndpoint(arg0 string) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DiscoverPollEndpoint", arg0) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DiscoverPollEndpoint indicates an expected call of DiscoverPollEndpoint. -func (mr *MockECSClientMockRecorder) DiscoverPollEndpoint(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscoverPollEndpoint", reflect.TypeOf((*MockECSClient)(nil).DiscoverPollEndpoint), arg0) -} - -// DiscoverServiceConnectEndpoint mocks base method. -func (m *MockECSClient) DiscoverServiceConnectEndpoint(arg0 string) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DiscoverServiceConnectEndpoint", arg0) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DiscoverServiceConnectEndpoint indicates an expected call of DiscoverServiceConnectEndpoint. -func (mr *MockECSClientMockRecorder) DiscoverServiceConnectEndpoint(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscoverServiceConnectEndpoint", reflect.TypeOf((*MockECSClient)(nil).DiscoverServiceConnectEndpoint), arg0) -} - -// DiscoverTelemetryEndpoint mocks base method. -func (m *MockECSClient) DiscoverTelemetryEndpoint(arg0 string) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DiscoverTelemetryEndpoint", arg0) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DiscoverTelemetryEndpoint indicates an expected call of DiscoverTelemetryEndpoint. -func (mr *MockECSClientMockRecorder) DiscoverTelemetryEndpoint(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscoverTelemetryEndpoint", reflect.TypeOf((*MockECSClient)(nil).DiscoverTelemetryEndpoint), arg0) -} - -// GetHostResources mocks base method. -func (m *MockECSClient) GetHostResources() (map[string]*ecs.Resource, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetHostResources") - ret0, _ := ret[0].(map[string]*ecs.Resource) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetHostResources indicates an expected call of GetHostResources. -func (mr *MockECSClientMockRecorder) GetHostResources() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHostResources", reflect.TypeOf((*MockECSClient)(nil).GetHostResources)) -} - -// GetResourceTags mocks base method. -func (m *MockECSClient) GetResourceTags(arg0 string) ([]*ecs.Tag, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetResourceTags", arg0) - ret0, _ := ret[0].([]*ecs.Tag) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetResourceTags indicates an expected call of GetResourceTags. -func (mr *MockECSClientMockRecorder) GetResourceTags(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResourceTags", reflect.TypeOf((*MockECSClient)(nil).GetResourceTags), arg0) -} - -// RegisterContainerInstance mocks base method. -func (m *MockECSClient) RegisterContainerInstance(arg0 string, arg1 []*ecs.Attribute, arg2 []*ecs.Tag, arg3 string, arg4 []*ecs.PlatformDevice, arg5 string) (string, string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "RegisterContainerInstance", arg0, arg1, arg2, arg3, arg4, arg5) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(string) - ret2, _ := ret[2].(error) - return ret0, ret1, ret2 -} - -// RegisterContainerInstance indicates an expected call of RegisterContainerInstance. -func (mr *MockECSClientMockRecorder) RegisterContainerInstance(arg0, arg1, arg2, arg3, arg4, arg5 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RegisterContainerInstance", reflect.TypeOf((*MockECSClient)(nil).RegisterContainerInstance), arg0, arg1, arg2, arg3, arg4, arg5) -} - -// SubmitAttachmentStateChange mocks base method. -func (m *MockECSClient) SubmitAttachmentStateChange(arg0 api.AttachmentStateChange) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SubmitAttachmentStateChange", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SubmitAttachmentStateChange indicates an expected call of SubmitAttachmentStateChange. -func (mr *MockECSClientMockRecorder) SubmitAttachmentStateChange(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubmitAttachmentStateChange", reflect.TypeOf((*MockECSClient)(nil).SubmitAttachmentStateChange), arg0) -} - -// SubmitContainerStateChange mocks base method. -func (m *MockECSClient) SubmitContainerStateChange(arg0 api.ContainerStateChange) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SubmitContainerStateChange", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SubmitContainerStateChange indicates an expected call of SubmitContainerStateChange. -func (mr *MockECSClientMockRecorder) SubmitContainerStateChange(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubmitContainerStateChange", reflect.TypeOf((*MockECSClient)(nil).SubmitContainerStateChange), arg0) -} - -// SubmitTaskStateChange mocks base method. -func (m *MockECSClient) SubmitTaskStateChange(arg0 api.TaskStateChange) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SubmitTaskStateChange", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// SubmitTaskStateChange indicates an expected call of SubmitTaskStateChange. -func (mr *MockECSClientMockRecorder) SubmitTaskStateChange(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SubmitTaskStateChange", reflect.TypeOf((*MockECSClient)(nil).SubmitTaskStateChange), arg0) -} - -// UpdateContainerInstancesState mocks base method. -func (m *MockECSClient) UpdateContainerInstancesState(arg0, arg1 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateContainerInstancesState", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 -} - -// UpdateContainerInstancesState indicates an expected call of UpdateContainerInstancesState. -func (mr *MockECSClientMockRecorder) UpdateContainerInstancesState(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateContainerInstancesState", reflect.TypeOf((*MockECSClient)(nil).UpdateContainerInstancesState), arg0, arg1) -} diff --git a/agent/api/statechange.go b/agent/api/statechange.go index 5183bb226dd..7df2b984082 100644 --- a/agent/api/statechange.go +++ b/agent/api/statechange.go @@ -23,14 +23,25 @@ import ( "github.com/aws/amazon-ecs-agent/agent/statechange" "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment" apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" + ecsmodel "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils" "github.com/aws/aws-sdk-go/aws" + "github.com/docker/go-connections/nat" "github.com/pkg/errors" ) +const ( + // ecsMaxNetworkBindingsLength is the maximum length of the ecs.NetworkBindings list sent as part of the + // container state change payload. Currently, this is enforced only when containerPortRanges are requested. + ecsMaxNetworkBindingsLength = 100 +) + // ContainerStateChange represents a state change that needs to be sent to the // SubmitContainerStateChange API type ContainerStateChange struct { @@ -99,6 +110,7 @@ type TaskStateChange struct { // AttachmentStateChange represents a state change that needs to be sent to the // SubmitAttachmentStateChanges API type AttachmentStateChange struct { + // Attachment is the attachment object to send Attachment attachment.Attachment } @@ -255,6 +267,33 @@ func (c *ContainerStateChange) String() string { return res } +// ToECSAgent converts the agent module level ContainerStateChange to ecs-agent module level ContainerStateChange. +func (c *ContainerStateChange) ToECSAgent() (*ecs.ContainerStateChange, error) { + pl, err := buildContainerStateChangePayload(*c) + if err != nil { + logger.Error("Could not convert agent container state change to ecs-agent container state change", + logger.Fields{ + "agentContainerStateChange": c.String(), + field.Error: err, + }) + return nil, err + } else if pl == nil { + return nil, nil + } + + return &ecs.ContainerStateChange{ + TaskArn: c.TaskArn, + RuntimeID: aws.StringValue(pl.RuntimeId), + ContainerName: c.ContainerName, + Status: c.Status, + ImageDigest: aws.StringValue(pl.ImageDigest), + Reason: aws.StringValue(pl.Reason), + ExitCode: utils.Int64PtrToIntPtr(pl.ExitCode), + NetworkBindings: pl.NetworkBindings, + MetadataGetter: newContainerMetadataGetter(c.Container), + }, nil +} + // String returns a human readable string representation of ManagedAgentStateChange func (m *ManagedAgentStateChange) String() string { res := fmt.Sprintf("containerName=%s managedAgentName=%s managedAgentStatus=%s", m.Container.Name, m.Name, m.Status.String()) @@ -349,16 +388,59 @@ func (change *TaskStateChange) String() string { return res } +// ToECSAgent converts the agent module level TaskStateChange to ecs-agent module level TaskStateChange. +func (change *TaskStateChange) ToECSAgent() (*ecs.TaskStateChange, error) { + output := &ecs.TaskStateChange{ + Attachment: change.Attachment, + TaskARN: change.TaskARN, + Status: change.Status, + Reason: change.Reason, + PullStartedAt: change.PullStartedAt, + PullStoppedAt: change.PullStoppedAt, + ExecutionStoppedAt: change.ExecutionStoppedAt, + MetadataGetter: newTaskMetadataGetter(change.Task), + } + + for _, managedAgentEvent := range change.ManagedAgents { + if mgspl := buildManagedAgentStateChangePayload(managedAgentEvent); mgspl != nil { + output.ManagedAgents = append(output.ManagedAgents, mgspl) + } + } + + containerEvents := make([]*ecsmodel.ContainerStateChange, len(change.Containers)) + for i, containerEvent := range change.Containers { + payload, err := buildContainerStateChangePayload(containerEvent) + if err != nil { + logger.Error("Could not convert agent task state change to ecs-agent task state change", logger.Fields{ + "agentTaskStateChange": change.String(), + field.Error: err, + }) + return nil, err + } + containerEvents[i] = payload + } + output.Containers = containerEvents + + return output, nil +} + // String returns a human readable string representation of this object func (change *AttachmentStateChange) String() string { if change.Attachment != nil { - return fmt.Sprintf("%s -> %v, %s", change.Attachment.GetAttachmentARN(), change.Attachment.GetAttachmentStatus(), - change.Attachment.String()) + return fmt.Sprintf("%s -> %v, %s", change.Attachment.GetAttachmentARN(), + change.Attachment.GetAttachmentStatus(), change.Attachment.String()) } return "" } +// ToECSAgent converts the agent module level AttachmentStateChange to ecs-agent module level AttachmentStateChange. +func (change *AttachmentStateChange) ToECSAgent() *ecs.AttachmentStateChange { + return &ecs.AttachmentStateChange{ + Attachment: change.Attachment, + } +} + // GetEventType returns an enum identifying the event type func (ContainerStateChange) GetEventType() statechange.EventType { return statechange.ContainerEvent @@ -377,3 +459,132 @@ func (ts TaskStateChange) GetEventType() statechange.EventType { func (AttachmentStateChange) GetEventType() statechange.EventType { return statechange.AttachmentEvent } + +func buildManagedAgentStateChangePayload(change ManagedAgentStateChange) *ecsmodel.ManagedAgentStateChange { + if !change.Status.ShouldReportToBackend() { + logger.Warn("Not submitting unsupported managed agent state", logger.Fields{ + field.Status: change.Status.String(), + field.ContainerName: change.Container.Name, + field.TaskARN: change.TaskArn, + }) + return nil + } + return &ecsmodel.ManagedAgentStateChange{ + ManagedAgentName: aws.String(change.Name), + ContainerName: aws.String(change.Container.Name), + Status: aws.String(change.Status.String()), + Reason: aws.String(change.Reason), + } +} + +func buildContainerStateChangePayload(change ContainerStateChange) (*ecsmodel.ContainerStateChange, error) { + if change.ContainerName == "" { + return nil, fmt.Errorf("container state change has no container name") + } + statechange := &ecsmodel.ContainerStateChange{ + ContainerName: aws.String(change.ContainerName), + } + if change.RuntimeID != "" { + statechange.RuntimeId = aws.String(change.RuntimeID) + } + if change.Reason != "" { + statechange.Reason = aws.String(change.Reason) + } + if change.ImageDigest != "" { + statechange.ImageDigest = aws.String(change.ImageDigest) + } + + stat := change.Status.String() + if stat != apicontainerstatus.ContainerStopped.String() && stat != apicontainerstatus.ContainerRunning.String() { + logger.Warn("Not submitting unsupported upstream container state", logger.Fields{ + field.Status: stat, + field.ContainerName: change.ContainerName, + field.TaskARN: change.TaskArn, + }) + return nil, nil + } + if stat == "DEAD" { + stat = apicontainerstatus.ContainerStopped.String() + } + statechange.Status = aws.String(stat) + + if change.ExitCode != nil { + exitCode := int64(aws.IntValue(change.ExitCode)) + statechange.ExitCode = aws.Int64(exitCode) + } + + networkBindings := getNetworkBindings(change) + // we enforce a limit on the no. of network bindings for containers with at-least 1 port range requested. + // this limit is enforced by ECS, and we fail early and don't call SubmitContainerStateChange. + if change.Container.HasPortRange() && len(networkBindings) > ecsMaxNetworkBindingsLength { + return nil, fmt.Errorf("no. of network bindings %v is more than the maximum supported no. %v, "+ + "container: %s "+"task: %s", len(networkBindings), ecsMaxNetworkBindingsLength, change.ContainerName, change.TaskArn) + } + statechange.NetworkBindings = networkBindings + + return statechange, nil +} + +// ProtocolBindIP used to store protocol and bindIP information associated to a particular host port +type ProtocolBindIP struct { + protocol string + bindIP string +} + +// getNetworkBindings returns the list of networkingBindings, sent to ECS as part of the container state change payload +func getNetworkBindings(change ContainerStateChange) []*ecsmodel.NetworkBinding { + networkBindings := []*ecsmodel.NetworkBinding{} + // hostPortToProtocolBindIPMap is a map to store protocol and bindIP information associated to host ports + // that belong to a range. This is used in case when there are multiple protocol/bindIP combinations associated to a + // port binding. example: when both IPv4 and IPv6 bindIPs are populated by docker. + hostPortToProtocolBindIPMap := map[int64][]ProtocolBindIP{} + + // ContainerPortSet consists of singular ports, and ports that belong to a range, but for which we were not able to + // find contiguous host ports and ask docker to pick instead. + containerPortSet := change.Container.GetContainerPortSet() + // each entry in the ContainerPortRangeMap implies that we found a contiguous host port range for the same + containerPortRangeMap := change.Container.GetContainerPortRangeMap() + + for _, binding := range change.PortBindings { + hostPort := int64(binding.HostPort) + containerPort := int64(binding.ContainerPort) + bindIP := binding.BindIP + protocol := binding.Protocol.String() + + // create network binding for each containerPort that exists in the singular ContainerPortSet + // for container ports that belong to a range, we'll have 1 consolidated network binding for the range + if _, ok := containerPortSet[int(containerPort)]; ok { + networkBindings = append(networkBindings, &ecsmodel.NetworkBinding{ + BindIP: aws.String(bindIP), + ContainerPort: aws.Int64(containerPort), + HostPort: aws.Int64(hostPort), + Protocol: aws.String(protocol), + }) + } else { + // populate hostPortToProtocolBindIPMap – this is used below when we construct network binding for ranges. + hostPortToProtocolBindIPMap[hostPort] = append(hostPortToProtocolBindIPMap[hostPort], + ProtocolBindIP{ + protocol: protocol, + bindIP: bindIP, + }) + } + } + + for containerPortRange, hostPortRange := range containerPortRangeMap { + // we check for protocol and bindIP information associated to any one of the host ports from the hostPortRange, + // all ports belonging to the same range share this information. + hostPort, _, _ := nat.ParsePortRangeToInt(hostPortRange) + if val, ok := hostPortToProtocolBindIPMap[int64(hostPort)]; ok { + for _, v := range val { + networkBindings = append(networkBindings, &ecsmodel.NetworkBinding{ + BindIP: aws.String(v.bindIP), + ContainerPortRange: aws.String(containerPortRange), + HostPortRange: aws.String(hostPortRange), + Protocol: aws.String(v.protocol), + }) + } + } + } + + return networkBindings +} diff --git a/agent/api/statechange_test.go b/agent/api/statechange_test.go index e33f4badc93..11fe5f3b506 100644 --- a/agent/api/statechange_test.go +++ b/agent/api/statechange_test.go @@ -24,10 +24,11 @@ import ( apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" "github.com/aws/amazon-ecs-agent/agent/api/serviceconnect" apitask "github.com/aws/amazon-ecs-agent/agent/api/task" - execcmd "github.com/aws/amazon-ecs-agent/agent/engine/execcmd" + "github.com/aws/amazon-ecs-agent/agent/engine/execcmd" apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" - + "github.com/aws/aws-sdk-go/aws" "github.com/stretchr/testify/assert" ) @@ -361,3 +362,136 @@ func TestNewManagedAgentChangeEvent(t *testing.T) { }) } } + +func TestGetNetworkBindings(t *testing.T) { + testContainerStateChange := getTestContainerStateChange() + expectedNetworkBindings := []*ecs.NetworkBinding{ + { + BindIP: aws.String("0.0.0.0"), + ContainerPort: aws.Int64(10), + HostPort: aws.Int64(10), + Protocol: aws.String("tcp"), + }, + { + BindIP: aws.String("1.2.3.4"), + ContainerPort: aws.Int64(12), + HostPort: aws.Int64(12), + Protocol: aws.String("udp"), + }, + { + BindIP: aws.String("5.6.7.8"), + ContainerPort: aws.Int64(15), + HostPort: aws.Int64(20), + Protocol: aws.String("tcp"), + }, + { + BindIP: aws.String("::"), + ContainerPortRange: aws.String("21-22"), + HostPortRange: aws.String("60001-60002"), + Protocol: aws.String("udp"), + }, + { + BindIP: aws.String("0.0.0.0"), + ContainerPortRange: aws.String("96-97"), + HostPortRange: aws.String("47001-47002"), + Protocol: aws.String("tcp"), + }, + } + + networkBindings := getNetworkBindings(testContainerStateChange) + assert.ElementsMatch(t, expectedNetworkBindings, networkBindings) +} + +func getTestContainerStateChange() ContainerStateChange { + testContainer := &apicontainer.Container{ + Name: "cont", + NetworkModeUnsafe: "bridge", + Ports: []apicontainer.PortBinding{ + { + ContainerPort: 10, + HostPort: 10, + Protocol: apicontainer.TransportProtocolTCP, + }, + { + ContainerPort: 12, + HostPort: 12, + Protocol: apicontainer.TransportProtocolUDP, + }, + { + ContainerPort: 15, + Protocol: apicontainer.TransportProtocolTCP, + }, + { + ContainerPortRange: "21-22", + Protocol: apicontainer.TransportProtocolUDP, + }, + { + ContainerPortRange: "96-97", + Protocol: apicontainer.TransportProtocolTCP, + }, + }, + ContainerHasPortRange: true, + ContainerPortSet: map[int]struct{}{ + 10: {}, + 12: {}, + 15: {}, + }, + ContainerPortRangeMap: map[string]string{ + "21-22": "60001-60002", + "96-97": "47001-47002", + }, + } + + testContainerStateChange := ContainerStateChange{ + TaskArn: "arn", + ContainerName: "cont", + Status: apicontainerstatus.ContainerRunning, + Container: testContainer, + PortBindings: []apicontainer.PortBinding{ + { + ContainerPort: 10, + HostPort: 10, + BindIP: "0.0.0.0", + Protocol: apicontainer.TransportProtocolTCP, + }, + { + ContainerPort: 12, + HostPort: 12, + BindIP: "1.2.3.4", + Protocol: apicontainer.TransportProtocolUDP, + }, + { + ContainerPort: 15, + HostPort: 20, + BindIP: "5.6.7.8", + Protocol: apicontainer.TransportProtocolTCP, + }, + { + ContainerPort: 21, + HostPort: 60001, + BindIP: "::", + Protocol: apicontainer.TransportProtocolUDP, + }, + { + ContainerPort: 22, + HostPort: 60002, + BindIP: "::", + Protocol: apicontainer.TransportProtocolUDP, + }, + { + ContainerPort: 96, + HostPort: 47001, + BindIP: "0.0.0.0", + Protocol: apicontainer.TransportProtocolTCP, + }, + { + ContainerPort: 97, + HostPort: 47002, + BindIP: "0.0.0.0", + Protocol: apicontainer.TransportProtocolTCP, + }, + }, + } + + return testContainerStateChange +} diff --git a/agent/api/task/task_test.go b/agent/api/task/task_test.go index 8d020d5880f..fa7eacbec79 100644 --- a/agent/api/task/task_test.go +++ b/agent/api/task/task_test.go @@ -160,6 +160,39 @@ func TestDockerConfigPortBinding(t *testing.T) { } } +func TestDockerConfigPortBindingContainerPortIsZero(t *testing.T) { + testTask := &Task{ + Containers: []*apicontainer.Container{ + { + Name: "ContainerHavingPortBindingWithContainerPortZero", + Ports: []apicontainer.PortBinding{ + { + ContainerPort: 0, + HostPort: 10, + BindIP: "", + Protocol: apicontainer.TransportProtocolTCP, + }, + { + ContainerPort: 0, + HostPort: 20, + BindIP: "", + Protocol: apicontainer.TransportProtocolUDP, + }, + }, + }, + }, + } + + dockerContainerConfig, err := testTask.DockerConfig(testTask.Containers[0], defaultDockerClientAPIVersion) + assert.Nil(t, err) + + // Ensure that port zero is not included in the set of container ports that are exposed for the container. + _, ok := dockerContainerConfig.ExposedPorts["0/tcp"] + assert.False(t, ok, "Unexpectedly could get exposed ports 0/tcp") + _, ok = dockerContainerConfig.ExposedPorts["0/udp"] + assert.False(t, ok, "Unexpectedly could get exposed ports 0/udp") +} + func TestDockerHostConfigCPUShareZero(t *testing.T) { testTask := &Task{ Containers: []*apicontainer.Container{ diff --git a/agent/app/agent.go b/agent/app/agent.go index 2f85fafb732..477ade44b63 100644 --- a/agent/app/agent.go +++ b/agent/app/agent.go @@ -22,8 +22,6 @@ import ( agentacs "github.com/aws/amazon-ecs-agent/agent/acs/session" "github.com/aws/amazon-ecs-agent/agent/acs/updater" - "github.com/aws/amazon-ecs-agent/agent/api" - "github.com/aws/amazon-ecs-agent/agent/api/ecsclient" "github.com/aws/amazon-ecs-agent/agent/app/factory" "github.com/aws/amazon-ecs-agent/agent/config" "github.com/aws/amazon-ecs-agent/agent/containermetadata" @@ -32,7 +30,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/dockerclient/dockerapi" "github.com/aws/amazon-ecs-agent/agent/dockerclient/sdkclientfactory" dockerdoctor "github.com/aws/amazon-ecs-agent/agent/doctor" // for Docker specific container instance health checks - ebs "github.com/aws/amazon-ecs-agent/agent/ebs" + "github.com/aws/amazon-ecs-agent/agent/ebs" "github.com/aws/amazon-ecs-agent/agent/ecscni" "github.com/aws/amazon-ecs-agent/agent/engine" dm "github.com/aws/amazon-ecs-agent/agent/engine/daemonmanager" @@ -56,7 +54,9 @@ import ( "github.com/aws/amazon-ecs-agent/agent/version" acsclient "github.com/aws/amazon-ecs-agent/ecs-agent/acs/client" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session" - "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" + ecsclient "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client" + ecsmodel "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" apierrors "github.com/aws/amazon-ecs-agent/ecs-agent/api/errors" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials/instancecreds" @@ -288,8 +288,21 @@ func (agent *ecsAgent) start() int { credentialsManager := credentials.NewManager() state := dockerstate.NewTaskEngineState() imageManager := engine.NewImageManager(agent.cfg, agent.dockerClient, state) - client := ecsclient.NewECSClient(agent.credentialProvider, agent.cfg, agent.ec2MetadataClient) - + cfgAccessor, err := config.NewAgentConfigAccessor(agent.cfg) + if err != nil { + logger.Critical("Unable to create new agent config accessor", logger.Fields{ + field.Error: err, + }) + return exitcodes.ExitError + } + client, err := ecsclient.NewECSClient(agent.credentialProvider, cfgAccessor, agent.ec2MetadataClient, + version.String(), ecsclient.WithIPv6PortBindingExcluded(true)) + if err != nil { + logger.Critical("Unable to create new ECS client", logger.Fields{ + field.Error: err, + }) + return exitcodes.ExitError + } agent.initializeResourceFields(credentialsManager) return agent.doStart(containerChangeEventStream, credentialsManager, state, imageManager, client, execcmd.NewManager()) } @@ -301,7 +314,7 @@ func (agent *ecsAgent) doStart(containerChangeEventStream *eventstream.EventStre credentialsManager credentials.Manager, state dockerstate.TaskEngineState, imageManager engine.ImageManager, - client api.ECSClient, + client ecs.ECSClient, execCmdMgr execcmd.Manager) int { // check docker version >= 1.9.0, exit agent if older if exitcode, ok := agent.verifyRequiredDockerVersion(); !ok { @@ -330,13 +343,13 @@ func (agent *ecsAgent) doStart(containerChangeEventStream *eventstream.EventStre // Find GPUs (if any) on the instance platformDevices := agent.getPlatformDevices() for _, device := range platformDevices { - if *device.Type == ecs.PlatformDeviceTypeGpu { + if *device.Type == ecsmodel.PlatformDeviceTypeGpu { gpuIDs = append(gpuIDs, *device.Id) } } } - hostResources["GPU"] = &ecs.Resource{ + hostResources["GPU"] = &ecsmodel.Resource{ Name: utils.Strptr("GPU"), Type: utils.Strptr("STRINGSET"), StringSetValue: aws.StringSlice(gpuIDs), @@ -370,7 +383,7 @@ func (agent *ecsAgent) doStart(containerChangeEventStream *eventstream.EventStre seelog.Errorf("Failed to load pause container: %v", loadPauseErr) } - var vpcSubnetAttributes []*ecs.Attribute + var vpcSubnetAttributes []*ecsmodel.Attribute // Check if Task ENI is enabled if agent.cfg.TaskENIEnabled.Enabled() { // check pause container image load @@ -557,7 +570,7 @@ func (agent *ecsAgent) newTaskEngine(containerChangeEventStream *eventstream.Eve credentialsManager credentials.Manager, state dockerstate.TaskEngineState, imageManager engine.ImageManager, - hostResources map[string]*ecs.Resource, + hostResources map[string]*ecsmodel.Resource, execCmdMgr execcmd.Manager, serviceConnectManager engineserviceconnect.Manager, daemonManagers map[string]dm.DaemonManager) (engine.TaskEngine, string, error) { @@ -734,8 +747,8 @@ func (agent *ecsAgent) newStateManager( // constructVPCSubnetAttributes returns vpc and subnet IDs of the instance as // an attribute list -func (agent *ecsAgent) constructVPCSubnetAttributes() []*ecs.Attribute { - return []*ecs.Attribute{ +func (agent *ecsAgent) constructVPCSubnetAttributes() []*ecsmodel.Attribute { + return []*ecsmodel.Attribute{ { Name: aws.String(vpcIDAttributeName), Value: aws.String(agent.vpc), @@ -781,8 +794,8 @@ func (agent *ecsAgent) loadManagedDaemonImage(dm dm.DaemonManager, imageManager // registerContainerInstance registers the container instance ID for the ECS Agent func (agent *ecsAgent) registerContainerInstance( - client api.ECSClient, - additionalAttributes []*ecs.Attribute) error { + client ecs.ECSClient, + additionalAttributes []*ecsmodel.Attribute) error { // Preflight request to make sure they're good if preflightCreds, err := agent.credentialProvider.Get(); err != nil || preflightCreds.AccessKeyID == "" { seelog.Errorf("Error getting valid credentials: %s", err) @@ -830,7 +843,7 @@ func (agent *ecsAgent) registerContainerInstance( if retriable, ok := err.(apierrors.Retriable); ok && !retriable.Retry() { return err } - if utils.IsAWSErrorCodeEqual(err, ecs.ErrCodeInvalidParameterException) { + if utils.IsAWSErrorCodeEqual(err, ecsmodel.ErrCodeInvalidParameterException) { logger.Critical("Instance registration attempt with an invalid parameter", logger.Fields{ field.Error: err, }) @@ -860,8 +873,8 @@ func (agent *ecsAgent) registerContainerInstance( // reregisterContainerInstance registers a container instance that has already been // registered with ECS. This is for cases where the ECS Agent is being restored // from a check point. -func (agent *ecsAgent) reregisterContainerInstance(client api.ECSClient, capabilities []*ecs.Attribute, - tags []*ecs.Tag, registrationToken string, platformDevices []*ecs.PlatformDevice, outpostARN string) error { +func (agent *ecsAgent) reregisterContainerInstance(client ecs.ECSClient, capabilities []*ecsmodel.Attribute, + tags []*ecsmodel.Tag, registrationToken string, platformDevices []*ecsmodel.PlatformDevice, outpostARN string) error { _, availabilityZone, err := client.RegisterContainerInstance(agent.containerInstanceARN, capabilities, tags, registrationToken, platformDevices, outpostARN) @@ -898,7 +911,7 @@ func (agent *ecsAgent) startAsyncRoutines( imageManager engine.ImageManager, taskEngine engine.TaskEngine, deregisterInstanceEventStream *eventstream.EventStream, - client api.ECSClient, + client ecs.ECSClient, taskHandler *eventhandler.TaskHandler, attachmentEventHandler *eventhandler.AttachmentEventHandler, state dockerstate.TaskEngineState, @@ -955,7 +968,7 @@ func (agent *ecsAgent) startAsyncRoutines( go session.Start(agent.ctx) } -func (agent *ecsAgent) startSpotInstanceDrainingPoller(ctx context.Context, client api.ECSClient) { +func (agent *ecsAgent) startSpotInstanceDrainingPoller(ctx context.Context, client ecs.ECSClient) { for !agent.spotInstanceDrainingPoller(client) { select { case <-ctx.Done(): @@ -968,7 +981,7 @@ func (agent *ecsAgent) startSpotInstanceDrainingPoller(ctx context.Context, clie // spotInstanceDrainingPoller returns true if spot instance interruption has been // set AND the container instance state is successfully updated to DRAINING. -func (agent *ecsAgent) spotInstanceDrainingPoller(client api.ECSClient) bool { +func (agent *ecsAgent) spotInstanceDrainingPoller(client ecs.ECSClient) bool { // this endpoint 404s unless a interruption has been set, so expect failure in most cases. resp, err := agent.ec2MetadataClient.SpotInstanceAction() if err == nil { @@ -1008,7 +1021,7 @@ func (agent *ecsAgent) startACSSession( credentialsManager credentials.Manager, taskEngine engine.TaskEngine, deregisterInstanceEventStream *eventstream.EventStream, - client api.ECSClient, + client ecs.ECSClient, state dockerstate.TaskEngineState, taskHandler *eventhandler.TaskHandler, doctor *doctor.Doctor) int { @@ -1107,7 +1120,7 @@ func (agent *ecsAgent) verifyRequiredDockerVersion() (int, bool) { } // getContainerInstanceTagsFromEC2API will retrieve the tags of this instance remotely. -func (agent *ecsAgent) getContainerInstanceTagsFromEC2API() ([]*ecs.Tag, error) { +func (agent *ecsAgent) getContainerInstanceTagsFromEC2API() ([]*ecsmodel.Tag, error) { // Get instance ID from ec2 metadata client. instanceID, err := agent.ec2MetadataClient.InstanceID() if err != nil { @@ -1119,7 +1132,7 @@ func (agent *ecsAgent) getContainerInstanceTagsFromEC2API() ([]*ecs.Tag, error) // mergeTags will merge the local tags and ec2 tags, for the overlap part, ec2 tags // will be overridden by local tags. -func mergeTags(localTags []*ecs.Tag, ec2Tags []*ecs.Tag) []*ecs.Tag { +func mergeTags(localTags []*ecsmodel.Tag, ec2Tags []*ecsmodel.Tag) []*ecsmodel.Tag { tagsMap := make(map[string]string) for _, ec2Tag := range ec2Tags { diff --git a/agent/app/agent_test.go b/agent/app/agent_test.go index b6239cb9c1a..02079e36883 100644 --- a/agent/app/agent_test.go +++ b/agent/app/agent_test.go @@ -25,7 +25,6 @@ import ( "testing" "time" - mock_api "github.com/aws/amazon-ecs-agent/agent/api/mocks" mock_factory "github.com/aws/amazon-ecs-agent/agent/app/factory/mocks" app_mocks "github.com/aws/amazon-ecs-agent/agent/app/mocks" "github.com/aws/amazon-ecs-agent/agent/config" @@ -50,6 +49,7 @@ import ( mock_loader "github.com/aws/amazon-ecs-agent/agent/utils/loader/mocks" mock_mobypkgwrapper "github.com/aws/amazon-ecs-agent/agent/utils/mobypkgwrapper/mocks" "github.com/aws/amazon-ecs-agent/agent/version" + mock_ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" apierrors "github.com/aws/amazon-ecs-agent/ecs-agent/api/errors" mock_credentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials/mocks" @@ -103,7 +103,7 @@ func setup(t *testing.T) (*gomock.Controller, *mock_credentials.MockManager, *mock_dockerstate.MockTaskEngineState, *mock_engine.MockImageManager, - *mock_api.MockECSClient, + *mock_ecs.MockECSClient, *mock_dockerapi.MockDockerClient, *mock_factory.MockStateManager, *mock_factory.MockSaveableOption, @@ -116,7 +116,7 @@ func setup(t *testing.T) (*gomock.Controller, mock_credentials.NewMockManager(ctrl), mock_dockerstate.NewMockTaskEngineState(ctrl), mock_engine.NewMockImageManager(ctrl), - mock_api.NewMockECSClient(ctrl), + mock_ecs.NewMockECSClient(ctrl), mock_dockerapi.NewMockDockerClient(ctrl), mock_factory.NewMockStateManager(ctrl), mock_factory.NewMockSaveableOption(ctrl), @@ -970,7 +970,7 @@ func TestReregisterContainerInstanceHappyPath(t *testing.T) { defer ctrl.Finish() mockDockerClient := mock_dockerapi.NewMockDockerClient(ctrl) - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) mockCredentialsProvider := app_mocks.NewMockProvider(ctrl) mockMobyPlugins := mock_mobypkgwrapper.NewMockPlugins(ctrl) mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(ctrl) @@ -1030,7 +1030,7 @@ func TestReregisterContainerInstanceInstanceTypeChanged(t *testing.T) { defer ctrl.Finish() mockDockerClient := mock_dockerapi.NewMockDockerClient(ctrl) - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) mockCredentialsProvider := app_mocks.NewMockProvider(ctrl) mockMobyPlugins := mock_mobypkgwrapper.NewMockPlugins(ctrl) mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(ctrl) @@ -1092,7 +1092,7 @@ func TestReregisterContainerInstanceAttributeError(t *testing.T) { defer ctrl.Finish() mockDockerClient := mock_dockerapi.NewMockDockerClient(ctrl) - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) mockCredentialsProvider := app_mocks.NewMockProvider(ctrl) mockMobyPlugins := mock_mobypkgwrapper.NewMockPlugins(ctrl) mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(ctrl) @@ -1152,7 +1152,7 @@ func TestReregisterContainerInstanceNonTerminalError(t *testing.T) { defer ctrl.Finish() mockDockerClient := mock_dockerapi.NewMockDockerClient(ctrl) - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) mockCredentialsProvider := app_mocks.NewMockProvider(ctrl) mockMobyPlugins := mock_mobypkgwrapper.NewMockPlugins(ctrl) mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(ctrl) @@ -1212,7 +1212,7 @@ func TestRegisterContainerInstanceWhenContainerInstanceARNIsNotSetHappyPath(t *t defer ctrl.Finish() mockDockerClient := mock_dockerapi.NewMockDockerClient(ctrl) - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) mockCredentialsProvider := app_mocks.NewMockProvider(ctrl) mockMobyPlugins := mock_mobypkgwrapper.NewMockPlugins(ctrl) mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(ctrl) @@ -1271,7 +1271,7 @@ func TestRegisterContainerInstanceWhenContainerInstanceARNIsNotSetCanRetryError( defer ctrl.Finish() mockDockerClient := mock_dockerapi.NewMockDockerClient(ctrl) - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) mockCredentialsProvider := app_mocks.NewMockProvider(ctrl) mockMobyPlugins := mock_mobypkgwrapper.NewMockPlugins(ctrl) mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(ctrl) @@ -1330,7 +1330,7 @@ func TestRegisterContainerInstanceWhenContainerInstanceARNIsNotSetCannotRetryErr defer ctrl.Finish() mockDockerClient := mock_dockerapi.NewMockDockerClient(ctrl) - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) mockCredentialsProvider := app_mocks.NewMockProvider(ctrl) mockMobyPlugins := mock_mobypkgwrapper.NewMockPlugins(ctrl) mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(ctrl) @@ -1389,7 +1389,7 @@ func TestRegisterContainerInstanceWhenContainerInstanceARNIsNotSetAttributeError defer ctrl.Finish() mockDockerClient := mock_dockerapi.NewMockDockerClient(ctrl) - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) mockCredentialsProvider := app_mocks.NewMockProvider(ctrl) mockMobyPlugins := mock_mobypkgwrapper.NewMockPlugins(ctrl) mockEC2Metadata := mock_ec2.NewMockEC2MetadataClient(ctrl) @@ -1677,7 +1677,7 @@ func TestSpotInstanceActionCheck_Sunny(t *testing.T) { ec2MetadataClient := mock_ec2.NewMockEC2MetadataClient(ctrl) ec2Client := mock_ec2.NewMockClient(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) for _, test := range tests { myARN := "myARN" @@ -1706,7 +1706,7 @@ func TestSpotInstanceActionCheck_Fail(t *testing.T) { ec2MetadataClient := mock_ec2.NewMockEC2MetadataClient(ctrl) ec2Client := mock_ec2.NewMockClient(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) for _, test := range tests { myARN := "myARN" @@ -1729,7 +1729,7 @@ func TestSpotInstanceActionCheck_NoInstanceActionYet(t *testing.T) { ec2MetadataClient := mock_ec2.NewMockEC2MetadataClient(ctrl) ec2Client := mock_ec2.NewMockClient(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) myARN := "myARN" agent := &ecsAgent{ diff --git a/agent/engine/serviceconnect/manager.go b/agent/engine/serviceconnect/manager.go index ce79e60f875..13ac1b4b592 100644 --- a/agent/engine/serviceconnect/manager.go +++ b/agent/engine/serviceconnect/manager.go @@ -14,11 +14,11 @@ package serviceconnect import ( - "github.com/aws/amazon-ecs-agent/agent/api" apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" apitask "github.com/aws/amazon-ecs-agent/agent/api/task" "github.com/aws/amazon-ecs-agent/agent/config" "github.com/aws/amazon-ecs-agent/agent/utils/loader" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" dockercontainer "github.com/docker/docker/api/types/container" ) @@ -29,7 +29,7 @@ type Manager interface { AugmentTaskContainer(task *apitask.Task, container *apicontainer.Container, hostConfig *dockercontainer.HostConfig) error CreateInstanceTask(config *config.Config) (*apitask.Task, error) AugmentInstanceContainer(task *apitask.Task, container *apicontainer.Container, hostConfig *dockercontainer.HostConfig) error - SetECSClient(client api.ECSClient, containerInstanceARN string) + SetECSClient(client ecs.ECSClient, containerInstanceARN string) GetLoadedAppnetVersion() (string, error) GetCapabilitiesForAppnetInterfaceVersion(appnetVersion string) ([]string, error) GetAppnetContainerTarballDir() string diff --git a/agent/engine/serviceconnect/manager_linux.go b/agent/engine/serviceconnect/manager_linux.go index 575e59821cc..699cf479296 100644 --- a/agent/engine/serviceconnect/manager_linux.go +++ b/agent/engine/serviceconnect/manager_linux.go @@ -27,7 +27,6 @@ import ( "github.com/aws/aws-sdk-go/aws" - "github.com/aws/amazon-ecs-agent/agent/api" apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" apiserviceconnect "github.com/aws/amazon-ecs-agent/agent/api/serviceconnect" apitask "github.com/aws/amazon-ecs-agent/agent/api/task" @@ -36,6 +35,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/taskresource" "github.com/aws/amazon-ecs-agent/agent/utils/loader" apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" @@ -119,7 +119,7 @@ type manager struct { agentContainerTag string appnetInterfaceVersion string - ecsClient api.ECSClient + ecsClient ecs.ECSClient containerInstanceARN string } @@ -143,7 +143,7 @@ func NewManager() Manager { } } -func (m *manager) SetECSClient(client api.ECSClient, containerInstanceARN string) { +func (m *manager) SetECSClient(client ecs.ECSClient, containerInstanceARN string) { m.ecsClient = client m.containerInstanceARN = containerInstanceARN } diff --git a/agent/engine/serviceconnect/manager_other.go b/agent/engine/serviceconnect/manager_other.go index be0b7f118a2..33637565a6b 100644 --- a/agent/engine/serviceconnect/manager_other.go +++ b/agent/engine/serviceconnect/manager_other.go @@ -21,12 +21,12 @@ import ( "fmt" "runtime" - "github.com/aws/amazon-ecs-agent/agent/api" apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" apitask "github.com/aws/amazon-ecs-agent/agent/api/task" "github.com/aws/amazon-ecs-agent/agent/config" "github.com/aws/amazon-ecs-agent/agent/dockerclient/dockerapi" "github.com/aws/amazon-ecs-agent/agent/utils/loader" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" "github.com/docker/docker/api/types" dockercontainer "github.com/docker/docker/api/types/container" @@ -61,7 +61,7 @@ func (*manager) IsLoaded(dockerClient dockerapi.DockerClient) (bool, error) { runtime.GOOS, runtime.GOARCH)) } -func (m *manager) SetECSClient(api.ECSClient, string) { +func (m *manager) SetECSClient(ecs.ECSClient, string) { } func (*manager) GetLoadedImageName() string { diff --git a/agent/engine/serviceconnect/mock/manager.go b/agent/engine/serviceconnect/mock/manager.go index c3c1e5baa09..f885b6ce7f5 100644 --- a/agent/engine/serviceconnect/mock/manager.go +++ b/agent/engine/serviceconnect/mock/manager.go @@ -22,11 +22,11 @@ import ( context "context" reflect "reflect" - api "github.com/aws/amazon-ecs-agent/agent/api" container "github.com/aws/amazon-ecs-agent/agent/api/container" task "github.com/aws/amazon-ecs-agent/agent/api/task" config "github.com/aws/amazon-ecs-agent/agent/config" dockerapi "github.com/aws/amazon-ecs-agent/agent/dockerclient/dockerapi" + ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" types "github.com/docker/docker/api/types" container0 "github.com/docker/docker/api/types/container" gomock "github.com/golang/mock/gomock" @@ -187,7 +187,7 @@ func (mr *MockManagerMockRecorder) LoadImage(arg0, arg1, arg2 interface{}) *gomo } // SetECSClient mocks base method. -func (m *MockManager) SetECSClient(arg0 api.ECSClient, arg1 string) { +func (m *MockManager) SetECSClient(arg0 ecs.ECSClient, arg1 string) { m.ctrl.T.Helper() m.ctrl.Call(m, "SetECSClient", arg0, arg1) } diff --git a/agent/eventhandler/attachment_handler.go b/agent/eventhandler/attachment_handler.go index 63f1407c241..416e52d3f84 100644 --- a/agent/eventhandler/attachment_handler.go +++ b/agent/eventhandler/attachment_handler.go @@ -23,6 +23,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/data" "github.com/aws/amazon-ecs-agent/agent/statechange" "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment/resource" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" "github.com/cihub/seelog" @@ -44,7 +45,7 @@ type AttachmentEventHandler struct { // lock is used to safely access the attachmentARNToHandler map lock sync.Mutex - client api.ECSClient + client ecs.ECSClient ctx context.Context } @@ -62,14 +63,14 @@ type attachmentHandler struct { // lock is used to ensure that the attached status of an attachment won't be sent multiple times lock sync.Mutex - client api.ECSClient + client ecs.ECSClient ctx context.Context } // NewAttachmentEventHandler returns a new AttachmentEventHandler object func NewAttachmentEventHandler(ctx context.Context, dataClient data.Client, - client api.ECSClient) *AttachmentEventHandler { + client ecs.ECSClient) *AttachmentEventHandler { return &AttachmentEventHandler{ ctx: ctx, client: client, @@ -136,7 +137,7 @@ func (handler *attachmentHandler) submitAttachmentEventOnce(attachmentChange *ap } seelog.Infof("AttachmentHandler: sending attachment state change: %s", attachmentChange.String()) - if err := handler.client.SubmitAttachmentStateChange(*attachmentChange); err != nil { + if err := handler.client.SubmitAttachmentStateChange(*attachmentChange.ToECSAgent()); err != nil { seelog.Errorf("AttachmentHandler: error submitting attachment state change [%s]: %v", attachmentChange.String(), err) return err } diff --git a/agent/eventhandler/attachment_handler_test.go b/agent/eventhandler/attachment_handler_test.go index d12cdcfbef4..19b0641b7fa 100644 --- a/agent/eventhandler/attachment_handler_test.go +++ b/agent/eventhandler/attachment_handler_test.go @@ -23,10 +23,11 @@ import ( "time" "github.com/aws/amazon-ecs-agent/agent/api" - mock_api "github.com/aws/amazon-ecs-agent/agent/api/mocks" "github.com/aws/amazon-ecs-agent/agent/data" "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment" "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment/resource" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" + mock_ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks" apierrors "github.com/aws/amazon-ecs-agent/ecs-agent/api/errors" ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" @@ -49,7 +50,7 @@ const ( func TestSendENIAttachmentEvent(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) attachmentEvent := eniAttachmentEvent(attachmentARN) @@ -65,7 +66,7 @@ func TestSendENIAttachmentEvent(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - client.EXPECT().SubmitAttachmentStateChange(gomock.Any()).Return(nil).Do(func(change api.AttachmentStateChange) { + client.EXPECT().SubmitAttachmentStateChange(gomock.Any()).Return(nil).Do(func(change ecs.AttachmentStateChange) { assert.NotNil(t, change.Attachment) assert.Equal(t, attachmentARN, change.Attachment.GetAttachmentARN()) wg.Done() @@ -79,7 +80,7 @@ func TestSendENIAttachmentEvent(t *testing.T) { func TestSendResAttachmentEvent(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) attachmentEvent := resAttachmentEvent(attachmentARN) @@ -95,7 +96,7 @@ func TestSendResAttachmentEvent(t *testing.T) { var wg sync.WaitGroup wg.Add(1) - client.EXPECT().SubmitAttachmentStateChange(gomock.Any()).Return(nil).Do(func(change api.AttachmentStateChange) { + client.EXPECT().SubmitAttachmentStateChange(gomock.Any()).Return(nil).Do(func(change ecs.AttachmentStateChange) { assert.NotNil(t, change.Attachment) assert.Equal(t, attachmentARN, change.Attachment.GetAttachmentARN()) wg.Done() @@ -109,7 +110,7 @@ func TestSendResAttachmentEvent(t *testing.T) { func TestSendAttachmentEventRetries(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) attachmentEvent := eniAttachmentEvent(attachmentARN) @@ -133,7 +134,7 @@ func TestSendAttachmentEventRetries(t *testing.T) { gomock.InOrder( client.EXPECT().SubmitAttachmentStateChange(gomock.Any()).Return(retriable).Do(func(interface{}) { wg.Done() }), - client.EXPECT().SubmitAttachmentStateChange(gomock.Any()).Return(nil).Do(func(change api.AttachmentStateChange) { + client.EXPECT().SubmitAttachmentStateChange(gomock.Any()).Return(nil).Do(func(change ecs.AttachmentStateChange) { assert.NotNil(t, change.Attachment) assert.Equal(t, attachmentARN, change.Attachment.GetAttachmentARN()) wg.Done() @@ -148,7 +149,7 @@ func TestSendAttachmentEventRetries(t *testing.T) { func TestSendMutipleAttachmentEventsMixedAttachments(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) attachmentEvent1 := eniAttachmentEvent("attachmentARN1") attachmentEvent2 := resAttachmentEvent("attachmentARN2") @@ -170,7 +171,7 @@ func TestSendMutipleAttachmentEventsMixedAttachments(t *testing.T) { submittedAttachments := make(map[string]bool) // note down submitted attachments mapLock := sync.Mutex{} // lock to protect the above map - client.EXPECT().SubmitAttachmentStateChange(gomock.Any()).Times(3).Return(nil).Do(func(change api.AttachmentStateChange) { + client.EXPECT().SubmitAttachmentStateChange(gomock.Any()).Times(3).Return(nil).Do(func(change ecs.AttachmentStateChange) { mapLock.Lock() defer mapLock.Unlock() @@ -192,7 +193,7 @@ func TestSendMutipleAttachmentEventsMixedAttachments(t *testing.T) { func TestSubmitAttachmentEventSucceeds(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) dataClient := newTestDataClient(t) @@ -211,7 +212,7 @@ func TestSubmitAttachmentEventSucceeds(t *testing.T) { } defer cancel() - client.EXPECT().SubmitAttachmentStateChange(gomock.Any()).Return(nil).Do(func(change api.AttachmentStateChange) { + client.EXPECT().SubmitAttachmentStateChange(gomock.Any()).Return(nil).Do(func(change ecs.AttachmentStateChange) { assert.NotNil(t, change.Attachment) assert.Equal(t, attachmentARN, change.Attachment.GetAttachmentARN()) }) @@ -227,7 +228,7 @@ func TestSubmitAttachmentEventSucceeds(t *testing.T) { func TestSubmitAttachmentEventAttachmentExpired(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) attachmentEvent := eniAttachmentEventWithExpiry(attachmentARN, 100*time.Millisecond) @@ -250,7 +251,7 @@ func TestSubmitAttachmentEventAttachmentExpired(t *testing.T) { func TestSubmitAttachmentEventAttachmentIsSent(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) attachmentEvent := resAttachmentEvent(attachmentARN) attachmentEvent.Attachment.SetSentStatus() diff --git a/agent/eventhandler/handler.go b/agent/eventhandler/handler.go index 9241b41f904..9d1b4c5b1bf 100644 --- a/agent/eventhandler/handler.go +++ b/agent/eventhandler/handler.go @@ -17,15 +17,15 @@ import ( "context" "fmt" - "github.com/aws/amazon-ecs-agent/agent/api" "github.com/aws/amazon-ecs-agent/agent/engine" "github.com/aws/amazon-ecs-agent/agent/statechange" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" "github.com/cihub/seelog" ) // HandleEngineEvents handles state change events from the state change event channel by sending it to // responsible event handler -func HandleEngineEvents(ctx context.Context, taskEngine engine.TaskEngine, client api.ECSClient, +func HandleEngineEvents(ctx context.Context, taskEngine engine.TaskEngine, client ecs.ECSClient, taskHandler *TaskHandler, attachmentEventHandler *AttachmentEventHandler) { for { @@ -51,7 +51,7 @@ func HandleEngineEvents(ctx context.Context, taskEngine engine.TaskEngine, clien } } -func handleEngineEvent(event statechange.Event, client api.ECSClient, taskHandler *TaskHandler, +func handleEngineEvent(event statechange.Event, client ecs.ECSClient, taskHandler *TaskHandler, attachmentEventHandler *AttachmentEventHandler) error { switch event.GetEventType() { case statechange.TaskEvent, statechange.ContainerEvent, statechange.ManagedAgentEvent: diff --git a/agent/eventhandler/handler_test.go b/agent/eventhandler/handler_test.go index cdc8636027c..4e3e9472792 100644 --- a/agent/eventhandler/handler_test.go +++ b/agent/eventhandler/handler_test.go @@ -21,10 +21,10 @@ import ( "sync" "testing" - "github.com/aws/amazon-ecs-agent/agent/api" - mock_api "github.com/aws/amazon-ecs-agent/agent/api/mocks" "github.com/aws/amazon-ecs-agent/agent/data" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" + mock_ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" ) @@ -33,7 +33,7 @@ func TestHandleEngineEvent(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) ctx, cancel := context.WithCancel(context.Background()) taskHandler := NewTaskHandler(ctx, data.NewNoopClient(), dockerstate.NewTaskEngineState(), client) @@ -53,14 +53,13 @@ func TestHandleEngineEvent(t *testing.T) { } assert.NoError(t, attachmentEvent.Attachment.StartTimer(timeoutFunc)) - client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change api.TaskStateChange) { + client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change ecs.TaskStateChange) { assert.Equal(t, 2, len(change.Containers)) - assert.Equal(t, taskARN, change.Containers[0].TaskArn) - assert.Equal(t, taskARN, change.Containers[1].TaskArn) + assert.Equal(t, taskARN, change.TaskARN) wg.Done() }) - client.EXPECT().SubmitAttachmentStateChange(gomock.Any()).Do(func(change api.AttachmentStateChange) { + client.EXPECT().SubmitAttachmentStateChange(gomock.Any()).Do(func(change ecs.AttachmentStateChange) { assert.NotNil(t, change.Attachment) assert.Equal(t, "attachmentARN", change.Attachment.GetAttachmentARN()) wg.Done() diff --git a/agent/eventhandler/task_handler.go b/agent/eventhandler/task_handler.go index c0fb6be24d5..28032f62b47 100644 --- a/agent/eventhandler/task_handler.go +++ b/agent/eventhandler/task_handler.go @@ -27,7 +27,8 @@ import ( "github.com/aws/amazon-ecs-agent/agent/metrics" "github.com/aws/amazon-ecs-agent/agent/statechange" "github.com/aws/amazon-ecs-agent/agent/utils" - "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" + ecsmodel "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" @@ -79,7 +80,7 @@ type TaskHandler struct { maxDrainEventsFrequency time.Duration state dockerstate.TaskEngineState - client api.ECSClient + client ecs.ECSClient ctx context.Context } @@ -104,7 +105,7 @@ type taskSendableEvents struct { func NewTaskHandler(ctx context.Context, dataClient data.Client, state dockerstate.TaskEngineState, - client api.ECSClient) *TaskHandler { + client ecs.ECSClient) *TaskHandler { // Create a handler and start the periodic event drain loop taskHandler := &TaskHandler{ ctx: ctx, @@ -131,7 +132,7 @@ func NewTaskHandler(ctx context.Context, // If the event is for task state change, it triggers the non-blocking // handler.submitTaskEvents method to submit the batched container state // changes and the task state change to ECS -func (handler *TaskHandler) AddStateChangeEvent(change statechange.Event, client api.ECSClient) error { +func (handler *TaskHandler) AddStateChangeEvent(change statechange.Event, client ecs.ECSClient) error { handler.lock.Lock() defer handler.lock.Unlock() switch change.GetEventType() { @@ -273,7 +274,7 @@ func (handler *TaskHandler) batchManagedAgentEventUnsafe(event api.ManagedAgentS // flushBatchUnsafe attaches the task arn's container events to TaskStateChange event // by creating the sendable event list. It then submits this event to ECS asynchronously -func (handler *TaskHandler) flushBatchUnsafe(taskStateChange *api.TaskStateChange, client api.ECSClient) { +func (handler *TaskHandler) flushBatchUnsafe(taskStateChange *api.TaskStateChange, client ecs.ECSClient) { taskStateChange.Containers = append(taskStateChange.Containers, handler.tasksToContainerStates[taskStateChange.TaskARN]...) // All container events for the task have now been copied to the @@ -318,7 +319,7 @@ func (handler *TaskHandler) getTaskEventsUnsafe(event *sendableEvent) *taskSenda // Continuously retries sending an event until it succeeds, sleeping between each // attempt -func (handler *TaskHandler) submitTaskEvents(taskEvents *taskSendableEvents, client api.ECSClient, taskARN string) { +func (handler *TaskHandler) submitTaskEvents(taskEvents *taskSendableEvents, client ecs.ECSClient, taskARN string) { defer metrics.MetricsEngineGlobal.RecordECSClientMetric("SUBMIT_TASK_EVENTS")() defer handler.removeTaskEvents(taskARN) @@ -358,7 +359,7 @@ func (handler *TaskHandler) removeTaskEvents(taskARN string) { // the handler's submitTaskEvents async method to submit this change if // there's no go routines already sending changes for this event list func (taskEvents *taskSendableEvents) sendChange(change *sendableEvent, - client api.ECSClient, + client ecs.ECSClient, handler *TaskHandler) { taskEvents.lock.Lock() @@ -442,7 +443,7 @@ func (taskEvents *taskSendableEvents) toStringUnsafe() string { // handleInvalidParamException removes the event from event queue when its parameters are // invalid to reduce redundant API call func handleInvalidParamException(err error, events *list.List, eventToSubmit *list.Element) { - if utils.IsAWSErrorCodeEqual(err, ecs.ErrCodeInvalidParameterException) { + if utils.IsAWSErrorCodeEqual(err, ecsmodel.ErrCodeInvalidParameterException) { event := eventToSubmit.Value.(*sendableEvent) logger.Warn("TaskHandler: Event is sent with invalid parameters; just removing", event.toFields()) events.Remove(eventToSubmit) diff --git a/agent/eventhandler/task_handler_test.go b/agent/eventhandler/task_handler_test.go index 30e0f89d516..4427d2e46cc 100644 --- a/agent/eventhandler/task_handler_test.go +++ b/agent/eventhandler/task_handler_test.go @@ -26,7 +26,6 @@ import ( "github.com/aws/amazon-ecs-agent/agent/api" apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" - mock_api "github.com/aws/amazon-ecs-agent/agent/api/mocks" apitask "github.com/aws/amazon-ecs-agent/agent/api/task" "github.com/aws/amazon-ecs-agent/agent/data" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" @@ -35,12 +34,15 @@ import ( "github.com/aws/amazon-ecs-agent/agent/utils" "github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment" apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" - "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" + mock_ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks" + ecsmodel "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" apierrors "github.com/aws/amazon-ecs-agent/ecs-agent/api/errors" apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface" mock_retry "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry/mock" + "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/golang/mock/gomock" "github.com/pkg/errors" @@ -52,7 +54,7 @@ const taskARN = "taskarn" func TestSendsEventsOneContainer(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) ctx, cancel := context.WithCancel(context.Background()) handler := NewTaskHandler(ctx, data.NewNoopClient(), dockerstate.NewTaskEngineState(), client) @@ -66,10 +68,9 @@ func TestSendsEventsOneContainer(t *testing.T) { contEvent2 := containerEvent(taskARN) taskEvent2 := taskEvent(taskARN) - client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change api.TaskStateChange) { + client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change ecs.TaskStateChange) { assert.Equal(t, 2, len(change.Containers)) - assert.Equal(t, taskARN, change.Containers[0].TaskArn) - assert.Equal(t, taskARN, change.Containers[1].TaskArn) + assert.Equal(t, taskARN, change.TaskARN) wg.Done() }) @@ -83,7 +84,7 @@ func TestSendsEventsOneContainer(t *testing.T) { func TestSendsEventsOneEventRetries(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) ctx, cancel := context.WithCancel(context.Background()) handler := NewTaskHandler(ctx, data.NewNoopClient(), dockerstate.NewTaskEngineState(), client) @@ -108,7 +109,7 @@ func TestSendsEventsOneEventRetries(t *testing.T) { func TestSendsEventsInvalidParametersEventsRemoved(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) ctx, cancel := context.WithCancel(context.Background()) handler := NewTaskHandler(ctx, data.NewNoopClient(), dockerstate.NewTaskEngineState(), client) @@ -122,7 +123,7 @@ func TestSendsEventsInvalidParametersEventsRemoved(t *testing.T) { client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(interface{}) { assert.Equal(t, 1, handler.tasksToEvents[taskARN].events.Len()) wg.Done() - }).Return(awserr.New(ecs.ErrCodeInvalidParameterException, "", nil)) + }).Return(awserr.New(ecsmodel.ErrCodeInvalidParameterException, "", nil)) handler.AddStateChangeEvent(taskEvent, client) @@ -136,7 +137,7 @@ func TestSendsEventsInvalidParametersEventsRemoved(t *testing.T) { func TestSendsEventsConcurrentLimit(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) ctx, cancel := context.WithCancel(context.Background()) handler := NewTaskHandler(ctx, data.NewNoopClient(), dockerstate.NewTaskEngineState(), client) @@ -174,7 +175,7 @@ func TestSendsEventsConcurrentLimit(t *testing.T) { func TestSendsEventsContainerDifferences(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) ctx, cancel := context.WithCancel(context.Background()) handler := NewTaskHandler(ctx, data.NewNoopClient(), dockerstate.NewTaskEngineState(), client) @@ -188,12 +189,10 @@ func TestSendsEventsContainerDifferences(t *testing.T) { contEvent2 := containerEventStopped(taskARN) taskEvent := taskEvent(taskARN) - client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change api.TaskStateChange) { - assert.Equal(t, 2, len(change.Containers)) - assert.Equal(t, taskARN, change.Containers[0].TaskArn) - assert.Equal(t, apicontainerstatus.ContainerRunning, change.Containers[0].Status) - assert.Equal(t, taskARN, change.Containers[1].TaskArn) - assert.Equal(t, apicontainerstatus.ContainerStopped, change.Containers[1].Status) + client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change ecs.TaskStateChange) { + assert.Equal(t, taskARN, change.TaskARN) + assert.Equal(t, apicontainerstatus.ContainerRunning.String(), aws.StringValue(change.Containers[0].Status)) + assert.Equal(t, apicontainerstatus.ContainerStopped.String(), aws.StringValue(change.Containers[1].Status)) wg.Done() }) @@ -207,7 +206,7 @@ func TestSendsEventsContainerDifferences(t *testing.T) { func TestSendsEventsTaskDifferences(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) dataClient := data.NewNoopClient() ctx, cancel := context.WithCancel(context.Background()) @@ -231,13 +230,13 @@ func TestSendsEventsTaskDifferences(t *testing.T) { contEventB2 := containerEventStopped(taskARNB) taskEventB := taskEventStopped(taskARNB) - client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change api.TaskStateChange) { + client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change ecs.TaskStateChange) { assert.Equal(t, taskARNA, change.TaskARN) wgAddEvent.Done() wg.Done() }) - client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change api.TaskStateChange) { + client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change ecs.TaskStateChange) { assert.Equal(t, taskARNB, change.TaskARN) wg.Done() }) @@ -257,7 +256,7 @@ func TestSendsEventsTaskDifferences(t *testing.T) { func TestSendsEventsDedupe(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) ctx, cancel := context.WithCancel(context.Background()) handler := NewTaskHandler(ctx, data.NewNoopClient(), dockerstate.NewTaskEngineState(), client) @@ -284,9 +283,8 @@ func TestSendsEventsDedupe(t *testing.T) { cont2.(api.ContainerStateChange).Container.SetSentStatus(apicontainerstatus.ContainerRunning) // Expect to send a task status but not a container status - client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change api.TaskStateChange) { + client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change ecs.TaskStateChange) { assert.Equal(t, 1, len(change.Containers)) - assert.Equal(t, taskARNB, change.Containers[0].TaskArn) assert.Equal(t, taskARNB, change.TaskARN) wg.Done() }) @@ -303,7 +301,7 @@ func TestCleanupTaskEventAfterSubmit(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) ctx, cancel := context.WithCancel(context.Background()) handler := NewTaskHandler(ctx, data.NewNoopClient(), dockerstate.NewTaskEngineState(), client) @@ -319,7 +317,7 @@ func TestCleanupTaskEventAfterSubmit(t *testing.T) { taskEvent3 := taskEvent(taskARN2) client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do( - func(change api.TaskStateChange) { + func(change ecs.TaskStateChange) { wg.Done() }).Times(3) @@ -370,7 +368,7 @@ func taskEventStopped(arn string) statechange.Event { func TestENISentStatusChange(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) task := &apitask.Task{ Arn: taskARN, @@ -455,7 +453,7 @@ func TestSubmitTaskEventsWhenSubmittingTaskRunningAfterStopped(t *testing.T) { defer ctrl.Finish() state := mock_dockerstate.NewMockTaskEngineState(ctrl) - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) handler := &TaskHandler{ state: state, @@ -493,7 +491,7 @@ func TestSubmitTaskEventsWhenSubmittingTaskRunningAfterStopped(t *testing.T) { var wg sync.WaitGroup wg.Add(1) gomock.InOrder( - client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change api.TaskStateChange) { + client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change ecs.TaskStateChange) { assert.Equal(t, apitaskstatus.TaskStopped, change.Status) }), backoff.EXPECT().Reset().Do(func() { @@ -519,7 +517,7 @@ func TestSubmitTaskEventsWhenSubmittingTaskStoppedAfterRunning(t *testing.T) { defer ctrl.Finish() state := mock_dockerstate.NewMockTaskEngineState(ctrl) - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) handler := &TaskHandler{ state: state, @@ -557,7 +555,7 @@ func TestSubmitTaskEventsWhenSubmittingTaskStoppedAfterRunning(t *testing.T) { var wg sync.WaitGroup wg.Add(1) gomock.InOrder( - client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change api.TaskStateChange) { + client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change ecs.TaskStateChange) { assert.Equal(t, apitaskstatus.TaskRunning, change.Status) }), backoff.EXPECT().Reset().Do(func() { @@ -573,7 +571,7 @@ func TestSubmitTaskEventsWhenSubmittingTaskStoppedAfterRunning(t *testing.T) { wg.Add(1) gomock.InOrder( - client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change api.TaskStateChange) { + client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change ecs.TaskStateChange) { assert.Equal(t, apitaskstatus.TaskStopped, change.Status) }), backoff.EXPECT().Reset().Do(func() { @@ -591,7 +589,7 @@ func TestSubmitTaskEventsWhenSubmittingTaskStoppedAfterRunning(t *testing.T) { func TestSendContainerAndManagedAgentEvents(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) ctx, cancel := context.WithCancel(context.Background()) handler := NewTaskHandler(ctx, data.NewNoopClient(), dockerstate.NewTaskEngineState(), client) @@ -604,11 +602,10 @@ func TestSendContainerAndManagedAgentEvents(t *testing.T) { cEevent1 := containerEvent(taskARN) taskEvent1 := taskEvent(taskARN) - client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change api.TaskStateChange) { + client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change ecs.TaskStateChange) { assert.Equal(t, 1, len(change.ManagedAgents)) assert.Equal(t, 1, len(change.Containers)) - assert.Equal(t, taskARN, change.ManagedAgents[0].TaskArn) - assert.Equal(t, taskARN, change.Containers[0].TaskArn) + assert.Equal(t, taskARN, change.TaskARN) wg.Done() }) @@ -624,7 +621,7 @@ func TestSendContainerAndManagedAgentEvents(t *testing.T) { func TestSendManagedAgentEventsTaskDifferences(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - client := mock_api.NewMockECSClient(ctrl) + client := mock_ecs.NewMockECSClient(ctrl) dataClient := data.NewNoopClient() ctx, cancel := context.WithCancel(context.Background()) @@ -647,13 +644,13 @@ func TestSendManagedAgentEventsTaskDifferences(t *testing.T) { maEventB1 := managedAgentEvent(taskARNB) taskEventB := taskEventStopped(taskARNB) - client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change api.TaskStateChange) { + client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change ecs.TaskStateChange) { assert.Equal(t, taskARNA, change.TaskARN) wgAddEvent.Done() wg.Done() }) - client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change api.TaskStateChange) { + client.EXPECT().SubmitTaskStateChange(gomock.Any()).Do(func(change ecs.TaskStateChange) { assert.Equal(t, taskARNB, change.TaskARN) wg.Done() }) diff --git a/agent/eventhandler/task_handler_types.go b/agent/eventhandler/task_handler_types.go index 431e7d2f6bb..db4b3d45717 100644 --- a/agent/eventhandler/task_handler_types.go +++ b/agent/eventhandler/task_handler_types.go @@ -17,15 +17,15 @@ import ( "container/list" "sync" - "github.com/aws/amazon-ecs-agent/ecs-agent/logger" - "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" - "github.com/aws/amazon-ecs-agent/agent/api" apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" apitask "github.com/aws/amazon-ecs-agent/agent/api/task" "github.com/aws/amazon-ecs-agent/agent/data" apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" "github.com/cihub/seelog" ) @@ -150,7 +150,7 @@ func (event *sendableEvent) send( sendStatusToECS sendStatusChangeToECS, setChangeSent setStatusSent, eventType string, - client api.ECSClient, + client ecs.ECSClient, eventToSubmit *list.Element, dataClient data.Client, backoff retry.Backoff, @@ -175,18 +175,32 @@ func (event *sendableEvent) send( } // sendStatusChangeToECS defines a function type for invoking the appropriate ECS state change API -type sendStatusChangeToECS func(client api.ECSClient, event *sendableEvent) error +type sendStatusChangeToECS func(client ecs.ECSClient, event *sendableEvent) error // sendContainerStatusToECS invokes the SubmitContainerStateChange API to send a // container status change to ECS -func sendContainerStatusToECS(client api.ECSClient, event *sendableEvent) error { - return client.SubmitContainerStateChange(event.containerChange) +func sendContainerStatusToECS(client ecs.ECSClient, event *sendableEvent) error { + containerStateChange, err := event.containerChange.ToECSAgent() + if err != nil { + return err + } + + // containerStateChange and err both nil in the case of an unsupported upstream container state. + // No-op (i.e., don't submit container state change) in this case. + if containerStateChange == nil { + return nil + } + return client.SubmitContainerStateChange(*containerStateChange) } // sendTaskStatusToECS invokes the SubmitTaskStateChange API to send a task // status change to ECS -func sendTaskStatusToECS(client api.ECSClient, event *sendableEvent) error { - return client.SubmitTaskStateChange(event.taskChange) +func sendTaskStatusToECS(client ecs.ECSClient, event *sendableEvent) error { + taskStateChange, err := event.taskChange.ToECSAgent() + if err != nil { + return err + } + return client.SubmitTaskStateChange(*taskStateChange) } // setStatusSent defines a function type to mark the event as sent diff --git a/agent/handlers/agentapi/taskprotection/factory.go b/agent/handlers/agentapi/taskprotection/factory.go index b9bbca48e2f..15c278131f2 100644 --- a/agent/handlers/agentapi/taskprotection/factory.go +++ b/agent/handlers/agentapi/taskprotection/factory.go @@ -13,14 +13,13 @@ package taskprotection import ( - "github.com/aws/amazon-ecs-agent/agent/api/ecsclient" "github.com/aws/amazon-ecs-agent/agent/config" "github.com/aws/amazon-ecs-agent/agent/version" - "github.com/aws/amazon-ecs-agent/ecs-agent/httpclient" - ecsapi "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" + ecsclient "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + "github.com/aws/amazon-ecs-agent/ecs-agent/httpclient" "github.com/aws/aws-sdk-go/aws" awscreds "github.com/aws/aws-sdk-go/aws/credentials" diff --git a/agent/handlers/task_server_setup.go b/agent/handlers/task_server_setup.go index ea0dedc792f..65f47114814 100644 --- a/agent/handlers/task_server_setup.go +++ b/agent/handlers/task_server_setup.go @@ -18,7 +18,6 @@ import ( "net/http" "time" - "github.com/aws/amazon-ecs-agent/agent/api" "github.com/aws/amazon-ecs-agent/agent/config" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" tpfactory "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection" @@ -27,6 +26,7 @@ import ( v4 "github.com/aws/amazon-ecs-agent/agent/handlers/v4" "github.com/aws/amazon-ecs-agent/agent/logger/audit" "github.com/aws/amazon-ecs-agent/agent/stats" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" auditinterface "github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit" "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" @@ -57,7 +57,7 @@ func taskServerSetup( credentialsManager credentials.Manager, auditLogger auditinterface.AuditLogger, state dockerstate.TaskEngineState, - ecsClient api.ECSClient, + ecsClient ecs.ECSClient, cluster string, statsEngine stats.Engine, steadyStateRate int, @@ -102,7 +102,7 @@ func taskServerSetup( // v2HandlersSetup adds all handlers in v2 package to the mux router. func v2HandlersSetup(muxRouter *mux.Router, state dockerstate.TaskEngineState, - ecsClient api.ECSClient, + ecsClient ecs.ECSClient, statsEngine stats.Engine, cluster string, credentialsManager credentials.Manager, @@ -123,7 +123,7 @@ func v2HandlersSetup(muxRouter *mux.Router, // v3HandlersSetup adds all handlers in v3 package to the mux router. func v3HandlersSetup(muxRouter *mux.Router, state dockerstate.TaskEngineState, - ecsClient api.ECSClient, + ecsClient ecs.ECSClient, statsEngine stats.Engine, cluster string, availabilityZone string, @@ -141,7 +141,7 @@ func v3HandlersSetup(muxRouter *mux.Router, // v4HandlerSetup adda all handlers in v4 package to the mux router func v4HandlersSetup(muxRouter *mux.Router, state dockerstate.TaskEngineState, - ecsClient api.ECSClient, + ecsClient ecs.ECSClient, statsEngine stats.Engine, cluster string, availabilityZone string, @@ -190,7 +190,7 @@ func ServeTaskHTTPEndpoint( ctx context.Context, credentialsManager credentials.Manager, state dockerstate.TaskEngineState, - ecsClient api.ECSClient, + ecsClient ecs.ECSClient, containerInstanceArn string, cfg *config.Config, statsEngine stats.Engine, diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index 203a6240140..f01bd9140b0 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -30,14 +30,13 @@ import ( "time" apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" - mock_api "github.com/aws/amazon-ecs-agent/agent/api/mocks" apitask "github.com/aws/amazon-ecs-agent/agent/api/task" "github.com/aws/amazon-ecs-agent/agent/config" mock_dockerstate "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate/mocks" v3 "github.com/aws/amazon-ecs-agent/agent/handlers/v3" mock_stats "github.com/aws/amazon-ecs-agent/agent/stats/mock" apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" - mock_taskprotection "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks" + mock_ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" @@ -781,7 +780,7 @@ func testErrorResponsesFromServer(t *testing.T, path string, expectedErrorMessag credentialsManager := mock_credentials.NewMockManager(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) server, err := taskServerSetup(credentialsManager, auditLog, nil, ecsClient, "", nil, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, tp.NewMockTaskProtectionClientFactoryInterface(ctrl)) @@ -818,7 +817,7 @@ func getResponseForCredentialsRequest(t *testing.T, expectedStatus int, defer ctrl.Finish() credentialsManager := mock_credentials.NewMockManager(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) server, err := taskServerSetup(credentialsManager, auditLog, nil, ecsClient, "", nil, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, containerInstanceArn, tp.NewMockTaskProtectionClientFactoryInterface(ctrl)) @@ -870,7 +869,7 @@ func TestV3ContainerAssociations(t *testing.T) { state := mock_dockerstate.NewMockTaskEngineState(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) gomock.InOrder( state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), @@ -904,7 +903,7 @@ func TestV3ContainerAssociation(t *testing.T) { state := mock_dockerstate.NewMockTaskEngineState(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) gomock.InOrder( state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), @@ -933,7 +932,7 @@ func TestV4ContainerAssociations(t *testing.T) { state := mock_dockerstate.NewMockTaskEngineState(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) gomock.InOrder( state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), @@ -967,7 +966,7 @@ func TestV4ContainerAssociation(t *testing.T) { state := mock_dockerstate.NewMockTaskEngineState(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) gomock.InOrder( state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), @@ -998,7 +997,7 @@ func TestTaskHTTPEndpoint301Redirect(t *testing.T) { state := mock_dockerstate.NewMockTaskEngineState(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, @@ -1041,7 +1040,7 @@ func TestTaskHTTPEndpointErrorCode404(t *testing.T) { state := mock_dockerstate.NewMockTaskEngineState(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, @@ -1081,7 +1080,7 @@ func TestTaskHTTPEndpointErrorCode400(t *testing.T) { state := mock_dockerstate.NewMockTaskEngineState(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, @@ -1120,7 +1119,7 @@ func TestTaskHTTPEndpointErrorCode500(t *testing.T) { state := mock_dockerstate.NewMockTaskEngineState(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, @@ -1190,7 +1189,7 @@ func TestV4TaskNotFoundError404(t *testing.T) { state := mock_dockerstate.NewMockTaskEngineState(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, @@ -1246,7 +1245,7 @@ func TestV4Unexpected500Error(t *testing.T) { state := mock_dockerstate.NewMockTaskEngineState(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, @@ -1304,7 +1303,7 @@ type TMDSTestCase[R TMDSResponse] struct { // Function to set expectations on mock stats engine setStatsEngineExpectations func(engine *mock_stats.MockEngine) // Function to set expectations on mock ECS Client - setECSClientExpectations func(ecsClient *mock_api.MockECSClient) + setECSClientExpectations func(ecsClient *mock_ecs.MockECSClient) // Function to set expectations on mock Task Protection Client Factory setTaskProtectionClientFactoryExpectations func( ctrl *gomock.Controller, factory *tp.MockTaskProtectionClientFactoryInterface) @@ -1332,7 +1331,7 @@ func testTMDSRequest[R TMDSResponse](t *testing.T, tc TMDSTestCase[R]) { state := mock_dockerstate.NewMockTaskEngineState(ctrl) auditLog := mock_audit.NewMockAuditLogger(ctrl) statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) credsManager := mock_credentials.NewMockManager(ctrl) taskProtectionClientFactory := tp.NewMockTaskProtectionClientFactoryInterface(ctrl) @@ -2042,7 +2041,7 @@ func TestV2TaskMetadataWithTags(t *testing.T) { testTMDSRequest(t, TMDSTestCase[v2.TaskResponse]{ path: path, setStateExpectations: happyStateExpectations, - setECSClientExpectations: func(ecsClient *mock_api.MockECSClient) { + setECSClientExpectations: func(ecsClient *mock_ecs.MockECSClient) { gomock.InOrder( ecsClient.EXPECT().GetResourceTags(containerInstanceArn). Return(ecsInstanceTags, nil), @@ -2061,7 +2060,7 @@ func TestV2TaskMetadataWithTags(t *testing.T) { testTMDSRequest(t, TMDSTestCase[v2.TaskResponse]{ path: v2BaseMetadataWithTagsPath, setStateExpectations: happyStateExpectations, - setECSClientExpectations: func(ecsClient *mock_api.MockECSClient) { + setECSClientExpectations: func(ecsClient *mock_ecs.MockECSClient) { gomock.InOrder( ecsClient.EXPECT().GetResourceTags(containerInstanceArn).Return(ecsInstanceTags, nil), ecsClient.EXPECT().GetResourceTags(taskARN).Return(nil, errors.New("error")), @@ -2077,7 +2076,7 @@ func TestV2TaskMetadataWithTags(t *testing.T) { testTMDSRequest(t, TMDSTestCase[v2.TaskResponse]{ path: v2BaseMetadataWithTagsPath, setStateExpectations: happyStateExpectations, - setECSClientExpectations: func(ecsClient *mock_api.MockECSClient) { + setECSClientExpectations: func(ecsClient *mock_ecs.MockECSClient) { gomock.InOrder( ecsClient.EXPECT().GetResourceTags(containerInstanceArn).Return(nil, errors.New("error")), ecsClient.EXPECT().GetResourceTags(taskARN).Return(ecsTaskTags, nil), @@ -2091,7 +2090,7 @@ func TestV2TaskMetadataWithTags(t *testing.T) { testTMDSRequest(t, TMDSTestCase[v2.TaskResponse]{ path: v2BaseMetadataWithTagsPath, setStateExpectations: happyStateExpectations, - setECSClientExpectations: func(ecsClient *mock_api.MockECSClient) { + setECSClientExpectations: func(ecsClient *mock_ecs.MockECSClient) { gomock.InOrder( ecsClient.EXPECT().GetResourceTags(containerInstanceArn).Return(nil, errors.New("error")), ecsClient.EXPECT().GetResourceTags(taskARN).Return(nil, errors.New("error")), @@ -2157,7 +2156,7 @@ func TestV3TaskMetadataWithTags(t *testing.T) { path := v3BasePath + v3EndpointID + "/taskWithTags" - happyECSClientExpectations := func(ecsClient *mock_api.MockECSClient) { + happyECSClientExpectations := func(ecsClient *mock_ecs.MockECSClient) { gomock.InOrder( ecsClient.EXPECT().GetResourceTags(containerInstanceArn).Return(ecsInstanceTags, nil), ecsClient.EXPECT().GetResourceTags(taskARN).Return(ecsTaskTags, nil), @@ -2190,7 +2189,7 @@ func TestV3TaskMetadataWithTags(t *testing.T) { testTMDSRequest(t, TMDSTestCase[v2.TaskResponse]{ path: path, setStateExpectations: happyStateExpectations, - setECSClientExpectations: func(ecsClient *mock_api.MockECSClient) { + setECSClientExpectations: func(ecsClient *mock_ecs.MockECSClient) { gomock.InOrder( ecsClient.EXPECT().GetResourceTags(containerInstanceArn).Return(nil, errors.New("error")), ecsClient.EXPECT().GetResourceTags(taskARN).Return(ecsTaskTags, nil), @@ -2206,7 +2205,7 @@ func TestV3TaskMetadataWithTags(t *testing.T) { testTMDSRequest(t, TMDSTestCase[v2.TaskResponse]{ path: path, setStateExpectations: happyStateExpectations, - setECSClientExpectations: func(ecsClient *mock_api.MockECSClient) { + setECSClientExpectations: func(ecsClient *mock_ecs.MockECSClient) { gomock.InOrder( ecsClient.EXPECT().GetResourceTags(containerInstanceArn).Return(ecsInstanceTags, nil), ecsClient.EXPECT().GetResourceTags(taskARN).Return(nil, errors.New("error")), @@ -2220,7 +2219,7 @@ func TestV3TaskMetadataWithTags(t *testing.T) { testTMDSRequest(t, TMDSTestCase[v2.TaskResponse]{ path: path, setStateExpectations: happyStateExpectations, - setECSClientExpectations: func(ecsClient *mock_api.MockECSClient) { + setECSClientExpectations: func(ecsClient *mock_ecs.MockECSClient) { gomock.InOrder( ecsClient.EXPECT().GetResourceTags(containerInstanceArn).Return(nil, errors.New("error")), ecsClient.EXPECT().GetResourceTags(taskARN).Return(nil, errors.New("error")), @@ -2331,7 +2330,7 @@ func TestV4TaskMetadataWithTags(t *testing.T) { ResourceARN: taskARN, } - happyECSClientExpectations := func(ecsClient *mock_api.MockECSClient) { + happyECSClientExpectations := func(ecsClient *mock_ecs.MockECSClient) { gomock.InOrder( ecsClient.EXPECT().GetResourceTags(containerInstanceArn).Return(ecsInstanceTags, nil), ecsClient.EXPECT().GetResourceTags(taskARN).Return(ecsTaskTags, nil), @@ -2368,7 +2367,7 @@ func TestV4TaskMetadataWithTags(t *testing.T) { testTMDSRequest(t, TMDSTestCase[v4.TaskResponse]{ path: path, setStateExpectations: happyStateExpectations, - setECSClientExpectations: func(ecsClient *mock_api.MockECSClient) { + setECSClientExpectations: func(ecsClient *mock_ecs.MockECSClient) { gomock.InOrder( ecsClient.EXPECT().GetResourceTags(containerInstanceArn).Return(nil, errors.New("error")), ecsClient.EXPECT().GetResourceTags(taskARN).Return(ecsTaskTags, nil), @@ -2385,7 +2384,7 @@ func TestV4TaskMetadataWithTags(t *testing.T) { testTMDSRequest(t, TMDSTestCase[v4.TaskResponse]{ path: path, setStateExpectations: happyStateExpectations, - setECSClientExpectations: func(ecsClient *mock_api.MockECSClient) { + setECSClientExpectations: func(ecsClient *mock_ecs.MockECSClient) { gomock.InOrder( ecsClient.EXPECT().GetResourceTags(containerInstanceArn).Return(ecsInstanceTags, nil), ecsClient.EXPECT().GetResourceTags(taskARN).Return(nil, errors.New("error")), @@ -2402,7 +2401,7 @@ func TestV4TaskMetadataWithTags(t *testing.T) { testTMDSRequest(t, TMDSTestCase[v4.TaskResponse]{ path: path, setStateExpectations: happyStateExpectations, - setECSClientExpectations: func(ecsClient *mock_api.MockECSClient) { + setECSClientExpectations: func(ecsClient *mock_ecs.MockECSClient) { gomock.InOrder( ecsClient.EXPECT().GetResourceTags(containerInstanceArn).Return(nil, errors.New("error")), ecsClient.EXPECT().GetResourceTags(taskARN).Return(nil, errors.New("error")), @@ -3018,7 +3017,7 @@ func TestGetTaskProtection(t *testing.T) { ctrl *gomock.Controller, factory *tp.MockTaskProtectionClientFactoryInterface, ) { - client := mock_taskprotection.NewMockECSTaskProtectionSDK(ctrl) + client := mock_ecs.NewMockECSTaskProtectionSDK(ctrl) client.EXPECT().GetTaskProtectionWithContext(gomock.Any(), &ecsInput).Return(output, err) factory.EXPECT().NewTaskProtectionClient(taskRoleCredentials()).Return(client) } @@ -3289,7 +3288,7 @@ func TestUpdateTaskProtection(t *testing.T) { ctrl *gomock.Controller, factory *tp.MockTaskProtectionClientFactoryInterface, ) { - client := mock_taskprotection.NewMockECSTaskProtectionSDK(ctrl) + client := mock_ecs.NewMockECSTaskProtectionSDK(ctrl) client.EXPECT().UpdateTaskProtectionWithContext(gomock.Any(), &ecsInput).Return(output, err) factory.EXPECT().NewTaskProtectionClient(taskRoleCredentials()).Return(client) } diff --git a/agent/handlers/v2/response.go b/agent/handlers/v2/response.go index e1fef4e27c5..52882fca18f 100644 --- a/agent/handlers/v2/response.go +++ b/agent/handlers/v2/response.go @@ -16,11 +16,11 @@ package v2 import ( "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/amazon-ecs-agent/agent/api" apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" v1 "github.com/aws/amazon-ecs-agent/agent/handlers/v1" apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface" tmdsresponse "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/response" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" @@ -38,7 +38,7 @@ const minimumCPUUnit = 2 func NewTaskResponse( taskARN string, state dockerstate.TaskEngineState, - ecsClient api.ECSClient, + ecsClient ecs.ECSClient, cluster string, az string, containerInstanceArn string, @@ -105,7 +105,7 @@ func NewTaskResponse( } // propagateTagsToMetadata retrieves container instance and task tags from ECS -func propagateTagsToMetadata(ecsClient api.ECSClient, containerInstanceARN, taskARN string, resp *tmdsv2.TaskResponse, includeV4Metadata bool) { +func propagateTagsToMetadata(ecsClient ecs.ECSClient, containerInstanceARN, taskARN string, resp *tmdsv2.TaskResponse, includeV4Metadata bool) { containerInstanceTags, err := ecsClient.GetResourceTags(containerInstanceARN) if err == nil { diff --git a/agent/handlers/v2/response_test.go b/agent/handlers/v2/response_test.go index 37fc68c5c37..e92b7c0273f 100644 --- a/agent/handlers/v2/response_test.go +++ b/agent/handlers/v2/response_test.go @@ -23,10 +23,10 @@ import ( "time" apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" - mock_api "github.com/aws/amazon-ecs-agent/agent/api/mocks" apitask "github.com/aws/amazon-ecs-agent/agent/api/task" mock_dockerstate "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate/mocks" apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" + mock_ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface" @@ -65,7 +65,7 @@ func TestTaskResponse(t *testing.T) { defer ctrl.Finish() state := mock_dockerstate.NewMockTaskEngineState(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) now := time.Now() task := &apitask.Task{ Arn: taskARN, @@ -161,7 +161,7 @@ func TestTaskResponseWithV4Metadata(t *testing.T) { defer ctrl.Finish() state := mock_dockerstate.NewMockTaskEngineState(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) now := time.Now() task := &apitask.Task{ Arn: taskARN, @@ -395,7 +395,7 @@ func TestTaskResponseMarshal(t *testing.T) { } state := mock_dockerstate.NewMockTaskEngineState(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) task := &apitask.Task{ Arn: taskARN, @@ -616,7 +616,7 @@ func TestTaskResponseWithV4TagsError(t *testing.T) { defer ctrl.Finish() state := mock_dockerstate.NewMockTaskEngineState(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) now := time.Now() task := &apitask.Task{ Arn: taskARN, diff --git a/agent/handlers/v2/task_container_metadata_handler.go b/agent/handlers/v2/task_container_metadata_handler.go index 5a72c06a2a2..171a8d4c844 100644 --- a/agent/handlers/v2/task_container_metadata_handler.go +++ b/agent/handlers/v2/task_container_metadata_handler.go @@ -18,8 +18,8 @@ import ( "fmt" "net/http" - "github.com/aws/amazon-ecs-agent/agent/api" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" "github.com/cihub/seelog" ) @@ -46,7 +46,7 @@ const ( var ContainerMetadataPath = TaskMetadataPathWithSlash + utils.ConstructMuxVar(metadataContainerIDMuxName, utils.AnythingButEmptyRegEx) // TaskContainerMetadataHandler returns the handler method for handling task and container metadata requests. -func TaskContainerMetadataHandler(state dockerstate.TaskEngineState, ecsClient api.ECSClient, cluster, az, containerInstanceArn string, propagateTags bool) func(http.ResponseWriter, *http.Request) { +func TaskContainerMetadataHandler(state dockerstate.TaskEngineState, ecsClient ecs.ECSClient, cluster, az, containerInstanceArn string, propagateTags bool) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { taskARN, err := getTaskARNByRequest(r, state) if err != nil { @@ -89,7 +89,7 @@ func WriteContainerMetadataResponse(w http.ResponseWriter, containerID string, s } // WriteTaskMetadataResponse writes the task metadata to response writer. -func WriteTaskMetadataResponse(w http.ResponseWriter, taskARN string, cluster string, state dockerstate.TaskEngineState, ecsClient api.ECSClient, az, containerInstanceArn string, propagateTags bool) { +func WriteTaskMetadataResponse(w http.ResponseWriter, taskARN string, cluster string, state dockerstate.TaskEngineState, ecsClient ecs.ECSClient, az, containerInstanceArn string, propagateTags bool) { // Generate a response for the task taskResponse, err := NewTaskResponse(taskARN, state, ecsClient, cluster, az, containerInstanceArn, propagateTags, false) if err != nil { diff --git a/agent/handlers/v3/task_metadata_handler.go b/agent/handlers/v3/task_metadata_handler.go index 61ebb608ca7..e43a6c83664 100644 --- a/agent/handlers/v3/task_metadata_handler.go +++ b/agent/handlers/v3/task_metadata_handler.go @@ -18,9 +18,9 @@ import ( "fmt" "net/http" - "github.com/aws/amazon-ecs-agent/agent/api" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" v2 "github.com/aws/amazon-ecs-agent/agent/handlers/v2" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" tmdsv2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2" "github.com/cihub/seelog" @@ -37,7 +37,7 @@ var TaskMetadataPath = "/v3/" + utils.ConstructMuxVar(V3EndpointIDMuxName, utils var TaskWithTagsMetadataPath = "/v3/" + utils.ConstructMuxVar(V3EndpointIDMuxName, utils.AnythingButSlashRegEx) + "/taskWithTags" // TaskMetadataHandler returns the handler method for handling task metadata requests. -func TaskMetadataHandler(state dockerstate.TaskEngineState, ecsClient api.ECSClient, cluster, az, containerInstanceArn string, propagateTags bool) func(http.ResponseWriter, *http.Request) { +func TaskMetadataHandler(state dockerstate.TaskEngineState, ecsClient ecs.ECSClient, cluster, az, containerInstanceArn string, propagateTags bool) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { taskARN, err := GetTaskARNByRequest(r, state) if err != nil { diff --git a/agent/handlers/v4/response.go b/agent/handlers/v4/response.go index 4ad64387472..e01eabf03de 100644 --- a/agent/handlers/v4/response.go +++ b/agent/handlers/v4/response.go @@ -14,11 +14,11 @@ package v4 import ( - "github.com/aws/amazon-ecs-agent/agent/api" apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" apitask "github.com/aws/amazon-ecs-agent/agent/api/task" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" v2 "github.com/aws/amazon-ecs-agent/agent/handlers/v2" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface" tmdsresponse "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/response" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" @@ -32,7 +32,7 @@ import ( func NewTaskResponse( taskARN string, state dockerstate.TaskEngineState, - ecsClient api.ECSClient, + ecsClient ecs.ECSClient, cluster string, az string, vpcID string, diff --git a/agent/handlers/v4/response_test.go b/agent/handlers/v4/response_test.go index 544e4382257..32e6b57c144 100644 --- a/agent/handlers/v4/response_test.go +++ b/agent/handlers/v4/response_test.go @@ -22,10 +22,10 @@ import ( "time" apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container" - mock_api "github.com/aws/amazon-ecs-agent/agent/api/mocks" apitask "github.com/aws/amazon-ecs-agent/agent/api/task" mock_dockerstate "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate/mocks" apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" + mock_ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks" apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" ni "github.com/aws/amazon-ecs-agent/ecs-agent/netlib/model/networkinterface" @@ -65,7 +65,7 @@ func TestNewTaskContainerResponses(t *testing.T) { defer ctrl.Finish() state := mock_dockerstate.NewMockTaskEngineState(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) now := time.Now() task := &apitask.Task{ Arn: taskARN, diff --git a/agent/handlers/v4/tmdsstate.go b/agent/handlers/v4/tmdsstate.go index 27947dfd340..3e8e098dc27 100644 --- a/agent/handlers/v4/tmdsstate.go +++ b/agent/handlers/v4/tmdsstate.go @@ -15,9 +15,9 @@ package v4 import ( "fmt" - "github.com/aws/amazon-ecs-agent/agent/api" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" "github.com/aws/amazon-ecs-agent/agent/stats" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" tmdsv4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" @@ -27,7 +27,7 @@ import ( type TMDSAgentState struct { state dockerstate.TaskEngineState statsEngine stats.Engine - ecsClient api.ECSClient + ecsClient ecs.ECSClient cluster string availabilityZone string vpcID string @@ -37,7 +37,7 @@ type TMDSAgentState struct { func NewTMDSAgentState( state dockerstate.TaskEngineState, statsEngine stats.Engine, - ecsClient api.ECSClient, + ecsClient ecs.ECSClient, cluster string, availabilityZone string, vpcID string, diff --git a/agent/stats/reporter/reporter.go b/agent/stats/reporter/reporter.go index a6b917951a3..d443378225a 100644 --- a/agent/stats/reporter/reporter.go +++ b/agent/stats/reporter/reporter.go @@ -18,10 +18,10 @@ import ( "errors" "time" - "github.com/aws/amazon-ecs-agent/agent/api" "github.com/aws/amazon-ecs-agent/agent/config" "github.com/aws/amazon-ecs-agent/agent/engine" "github.com/aws/amazon-ecs-agent/agent/version" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" "github.com/aws/amazon-ecs-agent/ecs-agent/eventstream" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" @@ -50,7 +50,7 @@ func NewDockerTelemetrySession( credentialProvider *credentials.Credentials, cfg *config.Config, deregisterInstanceEventStream *eventstream.EventStream, - ecsClient api.ECSClient, + ecsClient ecs.ECSClient, taskEngine engine.TaskEngine, metricsChannel <-chan ecstcs.TelemetryMessage, healthChannel <-chan ecstcs.HealthMessage, diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/session.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/session.go index cfd8299bc2c..7cc815321ca 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/session.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/session.go @@ -23,7 +23,7 @@ import ( "strings" "time" - "github.com/aws/amazon-ecs-agent/ecs-agent/api" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" @@ -72,7 +72,7 @@ type session struct { containerInstanceARN string cluster string credentialsProvider *credentials.Credentials - discoverEndpointClient api.ECSDiscoverEndpointSDK + ecsClient ecs.ECSClient inactiveInstanceCB func() agentVersion string agentHash string @@ -104,7 +104,7 @@ type session struct { // NewSession creates a new Session. func NewSession(containerInstanceARN string, cluster string, - discoverEndpointClient api.ECSDiscoverEndpointSDK, + ecsClient ecs.ECSClient, credentialsProvider *credentials.Credentials, inactiveInstanceCB func(), clientFactory wsclient.ClientFactory, @@ -130,7 +130,7 @@ func NewSession(containerInstanceARN string, return &session{ containerInstanceARN: containerInstanceARN, cluster: cluster, - discoverEndpointClient: discoverEndpointClient, + ecsClient: ecsClient, credentialsProvider: credentialsProvider, inactiveInstanceCB: inactiveInstanceCB, clientFactory: clientFactory, @@ -220,7 +220,7 @@ func (s *session) Start(ctx context.Context) error { // startSessionOnce creates a session with ACS and handles requests using the passed // in arguments. func (s *session) startSessionOnce(ctx context.Context) error { - acsEndpoint, err := s.discoverEndpointClient.DiscoverPollEndpoint(s.containerInstanceARN) + acsEndpoint, err := s.ecsClient.DiscoverPollEndpoint(s.containerInstanceARN) if err != nil { logger.Error("ACS: Unable to discover poll endpoint", logger.Fields{ field.Error: err, diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client.go new file mode 100644 index 00000000000..978304d57cf --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client.go @@ -0,0 +1,801 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package ecsclient + +import ( + "context" + "errors" + "fmt" + "net/http" + "runtime" + "strings" + "time" + + apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" + ecsmodel "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" + apierrors "github.com/aws/amazon-ecs-agent/ecs-agent/api/errors" + "github.com/aws/amazon-ecs-agent/ecs-agent/async" + "github.com/aws/amazon-ecs-agent/ecs-agent/config" + "github.com/aws/amazon-ecs-agent/ecs-agent/credentials/instancecreds" + "github.com/aws/amazon-ecs-agent/ecs-agent/ec2" + "github.com/aws/amazon-ecs-agent/ecs-agent/httpclient" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils" + "github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/docker/docker/pkg/meminfo" +) + +const ( + ecsMaxImageDigestLength = 255 + ecsMaxContainerReasonLength = 255 + ecsMaxTaskReasonLength = 1024 + ecsMaxRuntimeIDLength = 255 + defaultPollEndpointCacheTTL = 12 * time.Hour + azAttrName = "ecs.availability-zone" + cpuArchAttrName = "ecs.cpu-architecture" + osTypeAttrName = "ecs.os-type" + osFamilyAttrName = "ecs.os-family" + // RoundtripTimeout should only time out after dial and TLS handshake timeouts have elapsed. + // Add additional 2 seconds to the sum of these 2 timeouts to be extra sure of this. + RoundtripTimeout = httpclient.DefaultDialTimeout + httpclient.DefaultTLSHandshakeTimeout + 2*time.Second + // Below constants are used for SetInstanceIdentity retry with exponential backoff. + setInstanceIdRetryTimeOut = 30 * time.Second + setInstanceIdRetryBackoffMin = 100 * time.Millisecond + setInstanceIdRetryBackoffMax = 5 * time.Second + setInstanceIdRetryBackoffJitter = 0.2 + setInstanceIdRetryBackoffMultiple = 2 +) + +// ecsClient implements ECSClient interface. +type ecsClient struct { + credentialsProvider *credentials.Credentials + configAccessor config.AgentConfigAccessor + standardClient ecs.ECSStandardSDK + submitStateChangeClient ecs.ECSSubmitStateSDK + ec2metadata ec2.EC2MetadataClient + httpClient *http.Client + pollEndpointCache async.TTLCache + isFIPSDetected bool + shouldExcludeIPv6PortBinding bool + sascCustomRetryBackoff func(func() error) error + stscAttachmentCustomRetryBackoff func(func() error) error +} + +// NewECSClient creates a new ECSClient interface object. +func NewECSClient( + credentialsProvider *credentials.Credentials, + configAccessor config.AgentConfigAccessor, + ec2MetadataClient ec2.EC2MetadataClient, + agentVer string, + options ...ECSClientOption) (ecs.ECSClient, error) { + + client := &ecsClient{ + credentialsProvider: credentialsProvider, + configAccessor: configAccessor, + ec2metadata: ec2MetadataClient, + httpClient: httpclient.New(RoundtripTimeout, configAccessor.AcceptInsecureCert(), agentVer, configAccessor.OSType()), + pollEndpointCache: async.NewTTLCache(&async.TTL{Duration: defaultPollEndpointCacheTTL}), + } + + // Apply options to configure/override ECS client values. + for _, opt := range options { + opt(client) + } + + ecsConfig := newECSConfig(credentialsProvider, configAccessor, client.httpClient, client.isFIPSDetected) + s, err := session.NewSession(&ecsConfig) + if err != nil { + return nil, err + } + + if client.standardClient == nil { + client.standardClient = ecsmodel.New(s) + } + if client.submitStateChangeClient == nil { + client.submitStateChangeClient = newSubmitStateChangeClient(&ecsConfig) + } + + return client, nil +} + +func newECSConfig( + credentialsProvider *credentials.Credentials, + configAccessor config.AgentConfigAccessor, + httpClient *http.Client, + isFIPSEnabled bool) aws.Config { + var ecsConfig aws.Config + ecsConfig.HTTPClient = httpClient + ecsConfig.Credentials = credentialsProvider + ecsConfig.Region = aws.String(configAccessor.AWSRegion()) + // We should respect the endpoint given (if any) because it could be the Gamma or Zeta endpoint of ECS service which + // don't have the corresponding FIPS endpoints. Otherwise, when the host has FIPS enabled, we should tell SDK to + // pick the FIPS endpoint. + if configAccessor.APIEndpoint() != "" { + ecsConfig.Endpoint = aws.String(configAccessor.APIEndpoint()) + } else if isFIPSEnabled { + ecsConfig.UseFIPSEndpoint = endpoints.FIPSEndpointStateEnabled + } + return ecsConfig +} + +// CreateCluster creates a cluster from a given name and returns its ARN. +func (client *ecsClient) CreateCluster(clusterName string) (string, error) { + resp, err := client.standardClient.CreateCluster(&ecsmodel.CreateClusterInput{ClusterName: &clusterName}) + if err != nil { + logger.Critical("Could not create cluster", logger.Fields{ + field.Cluster: clusterName, + field.Error: err, + }) + return "", err + } + logger.Info("Successfully created a cluster", logger.Fields{ + field.Cluster: clusterName, + }) + return *resp.Cluster.ClusterName, nil +} + +// RegisterContainerInstance calculates the appropriate resources, creates +// the default cluster if necessary, and returns the registered +// ContainerInstanceARN if successful. Supplying a non-empty container +// instance ARN allows a container instance to update its registered +// resources. +func (client *ecsClient) RegisterContainerInstance(containerInstanceArn string, attributes []*ecsmodel.Attribute, + tags []*ecsmodel.Tag, registrationToken string, platformDevices []*ecsmodel.PlatformDevice, + outpostARN string) (string, string, error) { + + clusterRef := client.configAccessor.Cluster() + // If our clusterRef is empty, we should try to create the default. + if clusterRef == "" { + clusterRef = client.configAccessor.DefaultClusterName() + defer client.configAccessor.UpdateCluster(clusterRef) + // Attempt to register without checking existence of the cluster so that we don't require + // excess permissions in the case where the cluster already exists and is active. + containerInstanceArn, availabilityzone, err := client.registerContainerInstance(clusterRef, + containerInstanceArn, attributes, tags, registrationToken, platformDevices, outpostARN) + if err == nil { + return containerInstanceArn, availabilityzone, nil + } + + // If trying to register fails because the default cluster doesn't exist, try to create the cluster before + // calling register again. + if apierrors.IsClusterNotFoundError(err) { + clusterRef, err = client.CreateCluster(clusterRef) + if err != nil { + return "", "", err + } + } + } + return client.registerContainerInstance(clusterRef, containerInstanceArn, attributes, tags, registrationToken, + platformDevices, outpostARN) +} + +func (client *ecsClient) registerContainerInstance(clusterRef string, containerInstanceArn string, + attributes []*ecsmodel.Attribute, tags []*ecsmodel.Tag, registrationToken string, + platformDevices []*ecsmodel.PlatformDevice, outpostARN string) (string, string, error) { + + registerRequest := ecsmodel.RegisterContainerInstanceInput{Cluster: &clusterRef} + var registrationAttributes []*ecsmodel.Attribute + if containerInstanceArn != "" { + // We are re-connecting a previously registered instance, restored from snapshot. + registerRequest.ContainerInstanceArn = &containerInstanceArn + } else { + // This is a new instance, not previously registered. + // Custom attribute registration only happens on initial instance registration. + for _, attribute := range client.getCustomAttributes() { + logger.Debug("Added a new custom attribute", logger.Fields{ + field.AttributeName: aws.StringValue(attribute.Name), + field.AttributeValue: aws.StringValue(attribute.Value), + }) + registrationAttributes = append(registrationAttributes, attribute) + } + } + // Standard attributes are included with all registrations. + registrationAttributes = append(registrationAttributes, attributes...) + + // Add additional attributes, such as the OS type. + registrationAttributes = append(registrationAttributes, client.getAdditionalAttributes()...) + registrationAttributes = append(registrationAttributes, client.getOutpostAttribute(outpostARN)...) + + registerRequest.Attributes = registrationAttributes + if len(tags) > 0 { + registerRequest.Tags = tags + } + registerRequest.PlatformDevices = platformDevices + registerRequest = client.setInstanceIdentity(registerRequest) + + resources, err := client.getResources() + if err != nil { + return "", "", err + } + + registerRequest.TotalResources = resources + + registerRequest.ClientToken = ®istrationToken + resp, err := client.standardClient.RegisterContainerInstance(®isterRequest) + if err != nil { + logger.Error("Unable to register as a container instance with ECS", logger.Fields{ + field.Error: err, + }) + return "", "", err + } + + var availabilityzone = "" + if resp != nil { + for _, attr := range resp.ContainerInstance.Attributes { + if aws.StringValue(attr.Name) == azAttrName { + availabilityzone = aws.StringValue(attr.Value) + break + } + } + } + + logger.Info("Registered container instance with cluster!") + err = validateRegisteredAttributes(registerRequest.Attributes, resp.ContainerInstance.Attributes) + return aws.StringValue(resp.ContainerInstance.ContainerInstanceArn), availabilityzone, err +} + +func (client *ecsClient) setInstanceIdentity( + registerRequest ecsmodel.RegisterContainerInstanceInput) ecsmodel.RegisterContainerInstanceInput { + instanceIdentityDoc := "" + instanceIdentitySignature := "" + + if client.configAccessor.NoInstanceIdentityDocument() { + logger.Info("Fetching Instance ID Document has been disabled") + registerRequest.InstanceIdentityDocument = &instanceIdentityDoc + registerRequest.InstanceIdentityDocumentSignature = &instanceIdentitySignature + return registerRequest + } + + iidRetrieved := true + backoff := retry.NewExponentialBackoff(setInstanceIdRetryBackoffMin, setInstanceIdRetryBackoffMax, + setInstanceIdRetryBackoffJitter, setInstanceIdRetryBackoffMultiple) + ctx, cancel := context.WithTimeout(context.Background(), setInstanceIdRetryTimeOut) + defer cancel() + err := retry.RetryWithBackoffCtx(ctx, backoff, func() error { + var attemptErr error + logger.Debug("Attempting to get Instance Identity Document") + instanceIdentityDoc, attemptErr = client.ec2metadata.GetDynamicData(ec2.InstanceIdentityDocumentResource) + if attemptErr != nil { + logger.Debug("Unable to get instance identity document, retrying", logger.Fields{ + field.Error: attemptErr, + }) + // Force credentials to expire in case they are stale but not expired. + client.credentialsProvider.Expire() + client.credentialsProvider = instancecreds.GetCredentials(client.configAccessor.External()) + return apierrors.NewRetriableError(apierrors.NewRetriable(true), attemptErr) + } + logger.Debug("Successfully retrieved Instance Identity Document") + return nil + }) + if err != nil { + logger.Error("Unable to get instance identity document", logger.Fields{ + field.Error: err, + }) + iidRetrieved = false + } + registerRequest.InstanceIdentityDocument = &instanceIdentityDoc + + if iidRetrieved { + instanceIdentitySignature, err = client.ec2metadata. + GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource) + if err != nil { + logger.Error("Unable to get instance identity signature", logger.Fields{ + field.Error: err, + }) + } + } + + registerRequest.InstanceIdentityDocumentSignature = &instanceIdentitySignature + return registerRequest +} + +func attributesToMap(attributes []*ecsmodel.Attribute) map[string]string { + attributeMap := make(map[string]string) + attribs := attributes + for _, attribute := range attribs { + attributeMap[aws.StringValue(attribute.Name)] = aws.StringValue(attribute.Value) + } + return attributeMap +} + +func findMissingAttributes(expectedAttributes, actualAttributes map[string]string) ([]string, error) { + missingAttributes := make([]string, 0) + var err error + for key, val := range expectedAttributes { + if actualAttributes[key] != val { + missingAttributes = append(missingAttributes, key) + } else { + logger.Trace("Response contained expected value for attribute", logger.Fields{ + "key": key, + }) + } + } + if len(missingAttributes) > 0 { + err = apierrors.NewAttributeError("Attribute validation failed") + } + return missingAttributes, err +} + +func (client *ecsClient) getResources() ([]*ecsmodel.Resource, error) { + // Below are micro-optimizations - the pointers to integerStr and stringSetStr are used multiple times below. + integerStr := "INTEGER" + stringSetStr := "STRINGSET" + + cpu, mem := getCpuAndMemory() + remainingMem := mem - int64(client.configAccessor.ReservedMemory()) + logger.Info("Remaining memory", logger.Fields{ + "remainingMemory": remainingMem, + }) + if remainingMem < 0 { + return nil, fmt.Errorf( + "api register-container-instance: reserved memory is higher than available memory on the host, "+ + "total memory: %d, reserved: %d", mem, client.configAccessor.ReservedMemory()) + } + + cpuResource := ecsmodel.Resource{ + Name: aws.String("CPU"), + Type: &integerStr, + IntegerValue: &cpu, + } + memResource := ecsmodel.Resource{ + Name: aws.String("MEMORY"), + Type: &integerStr, + IntegerValue: &remainingMem, + } + portResource := ecsmodel.Resource{ + Name: aws.String("PORTS"), + Type: &stringSetStr, + StringSetValue: utils.Uint16SliceToStringSlice(client.configAccessor.ReservedPorts()), + } + udpPortResource := ecsmodel.Resource{ + Name: aws.String("PORTS_UDP"), + Type: &stringSetStr, + StringSetValue: utils.Uint16SliceToStringSlice(client.configAccessor.ReservedPortsUDP()), + } + + return []*ecsmodel.Resource{&cpuResource, &memResource, &portResource, &udpPortResource}, nil +} + +// GetHostResources calling getHostResources to get a list of CPU, MEMORY, PORTS and PORTS_UPD resources +// and return a resourceMap that map the resource name to each resource +func (client *ecsClient) GetHostResources() (map[string]*ecsmodel.Resource, error) { + resources, err := client.getResources() + if err != nil { + return nil, err + } + resourceMap := make(map[string]*ecsmodel.Resource) + for _, resource := range resources { + if *resource.Name == "PORTS" { + // Except for RCI, TCP Ports are named as PORTS_TCP in Agent for Host Resources purpose. + resource.Name = aws.String("PORTS_TCP") + } + resourceMap[*resource.Name] = resource + } + return resourceMap, nil +} + +func getCpuAndMemory() (int64, int64) { + memInfo, err := meminfo.Read() + mem := int64(0) + if err == nil { + mem = memInfo.MemTotal / 1024 / 1024 // MiB + } else { + logger.Error("Unable to get memory info", logger.Fields{ + field.Error: err, + }) + } + + cpu := runtime.NumCPU() * 1024 + + return int64(cpu), mem +} + +func validateRegisteredAttributes(expectedAttributes, actualAttributes []*ecsmodel.Attribute) error { + var err error + expectedAttributesMap := attributesToMap(expectedAttributes) + actualAttributesMap := attributesToMap(actualAttributes) + missingAttributes, err := findMissingAttributes(expectedAttributesMap, actualAttributesMap) + if err != nil { + msg := strings.Join(missingAttributes, ",") + logger.Error("Error registering attributes", logger.Fields{ + field.Error: err, + "missingAttributes": msg, + }) + } + return err +} + +func (client *ecsClient) getAdditionalAttributes() []*ecsmodel.Attribute { + attrs := []*ecsmodel.Attribute{ + { + Name: aws.String(osTypeAttrName), + Value: aws.String(client.configAccessor.OSType()), + }, + { + Name: aws.String(osFamilyAttrName), + Value: aws.String(client.configAccessor.OSFamily()), + }, + } + // Send CPU arch attribute directly when running on external capacity. When running on EC2 or Fargate launch type, + // this is not needed since the CPU arch is reported via instance identity document in those cases. + if client.configAccessor.External() { + attrs = append(attrs, &ecsmodel.Attribute{ + Name: aws.String(cpuArchAttrName), + Value: aws.String(getCPUArch()), + }) + } + return attrs +} + +func (client *ecsClient) getOutpostAttribute(outpostARN string) []*ecsmodel.Attribute { + if len(outpostARN) > 0 { + return []*ecsmodel.Attribute{ + { + Name: aws.String("ecs.outpost-arn"), + Value: aws.String(outpostARN), + }, + } + } + return []*ecsmodel.Attribute{} +} + +func (client *ecsClient) getCustomAttributes() []*ecsmodel.Attribute { + var attributes []*ecsmodel.Attribute + for attribute, value := range client.configAccessor.InstanceAttributes() { + attributes = append(attributes, &ecsmodel.Attribute{ + Name: aws.String(attribute), + Value: aws.String(value), + }) + } + return attributes +} + +func (client *ecsClient) SubmitTaskStateChange(change ecs.TaskStateChange) error { + if change.Attachment != nil && client.stscAttachmentCustomRetryBackoff != nil { + retryFunc := func() error { + err := client.submitTaskStateChange(change) + if err == nil { + return nil + } + return submitStateCustomRetriableError(err) + } + return client.stscAttachmentCustomRetryBackoff(retryFunc) + } + return client.submitTaskStateChange(change) +} + +func (client *ecsClient) submitTaskStateChange(change ecs.TaskStateChange) error { + if change.Attachment != nil { + // Confirm attachment by submitting attachment state change via SubmitTaskStateChange API (specifically in + // the input's Attachments field). + var attachments []*ecsmodel.AttachmentStateChange + eniStatus := change.Attachment.Status.String() + attachments = []*ecsmodel.AttachmentStateChange{ + { + AttachmentArn: aws.String(change.Attachment.AttachmentARN), + Status: aws.String(eniStatus), + }, + } + + _, err := client.submitStateChangeClient.SubmitTaskStateChange(&ecsmodel.SubmitTaskStateChangeInput{ + Cluster: aws.String(client.configAccessor.Cluster()), + Task: aws.String(change.TaskARN), + Attachments: attachments, + }) + if err != nil { + logger.Warn("Could not submit task state change associated with confirming attachment", + logger.Fields{ + field.Error: err, + "attachmentARN": change.Attachment.AttachmentARN, + field.Status: eniStatus, + }) + return err + } + + return nil + } + + req := ecsmodel.SubmitTaskStateChangeInput{ + Cluster: aws.String(client.configAccessor.Cluster()), + Task: aws.String(change.TaskARN), + Status: aws.String(change.Status.BackendStatus()), + Reason: aws.String(trimString(change.Reason, ecsMaxTaskReasonLength)), + PullStartedAt: change.PullStartedAt, + PullStoppedAt: change.PullStoppedAt, + ExecutionStoppedAt: change.ExecutionStoppedAt, + ManagedAgents: change.ManagedAgents, + Containers: formatContainers(change.Containers, client.shouldExcludeIPv6PortBinding, change.TaskARN), + } + + _, err := client.submitStateChangeClient.SubmitTaskStateChange(&req) + if err != nil { + logger.Warn("Could not submit task state change", logger.Fields{ + field.Error: err, + "taskStateChange": change.String(), + }) + return err + } + + return nil +} + +func (client *ecsClient) SubmitContainerStateChange(change ecs.ContainerStateChange) error { + input := ecsmodel.SubmitContainerStateChangeInput{ + Cluster: aws.String(client.configAccessor.Cluster()), + ContainerName: aws.String(change.ContainerName), + Task: aws.String(change.TaskArn), + } + + if change.RuntimeID != "" { + input.RuntimeId = aws.String(trimString(change.RuntimeID, ecsMaxRuntimeIDLength)) + } + + if change.Reason != "" { + input.Reason = aws.String(trimString(change.Reason, ecsMaxContainerReasonLength)) + } + + stat := change.Status.String() + if stat == "DEAD" { + stat = apicontainerstatus.ContainerStopped.String() + } + if stat != apicontainerstatus.ContainerStopped.String() && stat != apicontainerstatus.ContainerRunning.String() { + logger.Info("Not submitting unsupported upstream container state", logger.Fields{ + field.ContainerName: change.ContainerName, + field.Status: stat, + field.TaskARN: change.TaskArn, + }) + return nil + } + input.Status = aws.String(stat) + + if change.ExitCode != nil { + exitCode := int64(aws.IntValue(change.ExitCode)) + input.ExitCode = aws.Int64(exitCode) + } + + networkBindings := change.NetworkBindings + if client.shouldExcludeIPv6PortBinding { + networkBindings = excludeIPv6PortBindingFromNetworkBindings(networkBindings, change.ContainerName, + change.TaskArn) + } + input.NetworkBindings = networkBindings + + _, err := client.submitStateChangeClient.SubmitContainerStateChange(&input) + if err != nil { + logger.Warn("Could not submit container state change", logger.Fields{ + field.Error: err, + field.TaskARN: change.TaskArn, + "containerStateChange": change.String(), + }) + return err + } + return nil +} + +func (client *ecsClient) SubmitAttachmentStateChange(change ecs.AttachmentStateChange) error { + if client.sascCustomRetryBackoff != nil { + retryFunc := func() error { + err := client.submitAttachmentStateChange(change) + if err == nil { + return nil + } + return submitStateCustomRetriableError(err) + } + return client.sascCustomRetryBackoff(retryFunc) + } + return client.submitAttachmentStateChange(change) +} + +func (client *ecsClient) submitAttachmentStateChange(change ecs.AttachmentStateChange) error { + attachmentStatus := change.Attachment.GetAttachmentStatus() + + req := ecsmodel.SubmitAttachmentStateChangesInput{ + Cluster: aws.String(client.configAccessor.Cluster()), + Attachments: []*ecsmodel.AttachmentStateChange{ + { + AttachmentArn: aws.String(change.Attachment.GetAttachmentARN()), + Status: aws.String(attachmentStatus.String()), + }, + }, + } + + _, err := client.submitStateChangeClient.SubmitAttachmentStateChanges(&req) + if err != nil { + logger.Warn("Could not submit attachment state change", logger.Fields{ + field.Error: err, + "attachmentStateChange": change.String(), + }) + return err + } + + return nil +} + +func submitStateCustomRetriableError(err error) error { + retry := true + aerr, ok := err.(awserr.Error) + if ok { + switch aerr.Code() { + case ecsmodel.ErrCodeInvalidParameterException: + retry = false + case ecsmodel.ErrCodeAccessDeniedException: + retry = false + case ecsmodel.ErrCodeClientException: + retry = false + } + } + return apierrors.NewRetriableError(apierrors.NewRetriable(retry), err) +} + +func (client *ecsClient) DiscoverPollEndpoint(containerInstanceArn string) (string, error) { + resp, err := client.discoverPollEndpoint(containerInstanceArn) + if err != nil { + return "", err + } + if resp.Endpoint == nil { + return "", errors.New("no endpoint returned; nil") + } + + return aws.StringValue(resp.Endpoint), nil +} + +func (client *ecsClient) DiscoverTelemetryEndpoint(containerInstanceArn string) (string, error) { + resp, err := client.discoverPollEndpoint(containerInstanceArn) + if err != nil { + return "", err + } + if resp.TelemetryEndpoint == nil { + return "", errors.New("no telemetry endpoint returned; nil") + } + + return aws.StringValue(resp.TelemetryEndpoint), nil +} + +func (client *ecsClient) DiscoverServiceConnectEndpoint(containerInstanceArn string) (string, error) { + resp, err := client.discoverPollEndpoint(containerInstanceArn) + if err != nil { + return "", err + } + if resp.ServiceConnectEndpoint == nil { + return "", errors.New("no ServiceConnect endpoint returned; nil") + } + + return aws.StringValue(resp.ServiceConnectEndpoint), nil +} + +func (client *ecsClient) discoverPollEndpoint(containerInstanceArn string) (*ecsmodel.DiscoverPollEndpointOutput, + error) { + // Try getting an entry from the cache. + cachedEndpoint, expired, found := client.pollEndpointCache.Get(containerInstanceArn) + if !expired && found { + // Cache hit and not expired. Return the output. + if output, ok := cachedEndpoint.(*ecsmodel.DiscoverPollEndpointOutput); ok { + logger.Info("Using cached DiscoverPollEndpoint", logger.Fields{ + field.Endpoint: aws.StringValue(output.Endpoint), + field.TelemetryEndpoint: aws.StringValue(output.TelemetryEndpoint), + field.ServiceConnectEndpoint: aws.StringValue(output.ServiceConnectEndpoint), + field.ContainerInstanceARN: containerInstanceArn, + }) + return output, nil + } + } + + // Cache miss or expired, invoke the ECS DiscoverPollEndpoint API. + logger.Debug("Invoking DiscoverPollEndpoint", logger.Fields{ + field.ContainerInstanceARN: containerInstanceArn, + }) + output, err := client.standardClient.DiscoverPollEndpoint(&ecsmodel.DiscoverPollEndpointInput{ + ContainerInstance: &containerInstanceArn, + Cluster: aws.String(client.configAccessor.Cluster()), + }) + if err != nil { + // If we got an error calling the API, fallback to an expired cached endpoint if + // we have it. + if expired { + if output, ok := cachedEndpoint.(*ecsmodel.DiscoverPollEndpointOutput); ok { + logger.Info("Error calling DiscoverPollEndpoint. Using cached-but-expired endpoint as a fallback.", + logger.Fields{ + field.Endpoint: aws.StringValue(output.Endpoint), + field.TelemetryEndpoint: aws.StringValue(output.TelemetryEndpoint), + field.ServiceConnectEndpoint: aws.StringValue(output.ServiceConnectEndpoint), + field.ContainerInstanceARN: containerInstanceArn, + }) + return output, nil + } + } + return nil, err + } + + // Cache the response from ECS. + client.pollEndpointCache.Set(containerInstanceArn, output) + return output, nil +} + +func (client *ecsClient) GetResourceTags(resourceArn string) ([]*ecsmodel.Tag, error) { + output, err := client.standardClient.ListTagsForResource(&ecsmodel.ListTagsForResourceInput{ + ResourceArn: &resourceArn, + }) + if err != nil { + return nil, err + } + return output.Tags, nil +} + +func (client *ecsClient) UpdateContainerInstancesState(instanceARN string, status string) error { + logger.Debug("Invoking UpdateContainerInstancesState", logger.Fields{ + field.Status: status, + field.ContainerInstanceARN: instanceARN, + }) + _, err := client.standardClient.UpdateContainerInstancesState(&ecsmodel.UpdateContainerInstancesStateInput{ + ContainerInstances: []*string{aws.String(instanceARN)}, + Status: aws.String(status), + Cluster: aws.String(client.configAccessor.Cluster()), + }) + return err +} + +func formatContainers(containers []*ecsmodel.ContainerStateChange, shouldExcludeIPv6PortBinding bool, + taskARN string) []*ecsmodel.ContainerStateChange { + var result []*ecsmodel.ContainerStateChange + for _, c := range containers { + if c.RuntimeId != nil { + c.RuntimeId = aws.String(trimString(aws.StringValue(c.RuntimeId), ecsMaxRuntimeIDLength)) + } + if c.Reason != nil { + c.Reason = aws.String(trimString(aws.StringValue(c.Reason), ecsMaxContainerReasonLength)) + } + if c.ImageDigest != nil { + c.ImageDigest = aws.String(trimString(aws.StringValue(c.ImageDigest), ecsMaxImageDigestLength)) + } + if shouldExcludeIPv6PortBinding { + c.NetworkBindings = excludeIPv6PortBindingFromNetworkBindings(c.NetworkBindings, + aws.StringValue(c.ContainerName), taskARN) + } + result = append(result, c) + } + return result +} + +func excludeIPv6PortBindingFromNetworkBindings(networkBindings []*ecsmodel.NetworkBinding, containerName, + taskARN string) []*ecsmodel.NetworkBinding { + var result []*ecsmodel.NetworkBinding + for _, binding := range networkBindings { + if aws.StringValue(binding.BindIP) == "::" { + logger.Debug("Exclude IPv6 port binding", logger.Fields{ + "portBinding": binding, + field.ContainerName: containerName, + field.TaskARN: taskARN, + }) + continue + } + result = append(result, binding) + } + return result +} + +func trimString(inputString string, maxLen int) string { + if len(inputString) > maxLen { + trimmed := inputString[0:maxLen] + return trimmed + } else { + return inputString + } +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client_option.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client_option.go new file mode 100644 index 00000000000..ee9ad5c8686 --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client_option.go @@ -0,0 +1,89 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package ecsclient + +import ( + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" + "github.com/aws/amazon-ecs-agent/ecs-agent/async" +) + +// ECSClientOption allows for configuration of an ecsClient. +type ECSClientOption func(*ecsClient) + +// WithFIPSDetected is an ECSClientOption that configures the +// ecsClient.isFIPSDetected with the value passed as a parameter. +func WithFIPSDetected(val bool) ECSClientOption { + return func(client *ecsClient) { + client.isFIPSDetected = val + } +} + +// WithDiscoverPollEndpointCacheTTL is an ECSClientOption that configures the +// ecsClient.pollEndpointCache.ttl with the value passed as a parameter. +func WithDiscoverPollEndpointCacheTTL(t *async.TTL) ECSClientOption { + return func(client *ecsClient) { + client.pollEndpointCache.SetTTL(t) + } +} + +// WithIPv6PortBindingExcluded is an ECSClientOption that configures the +// ecsClient.shouldExcludeIPv6PortBinding with the value passed as a parameter. +func WithIPv6PortBindingExcluded(val bool) ECSClientOption { + return func(client *ecsClient) { + client.shouldExcludeIPv6PortBinding = val + } +} + +// WithSASCCustomRetryBackoff is an ECSClientOption that configures the +// ecsClient.sascCustomRetryBackoff with the value passed as a parameter. +func WithSASCCustomRetryBackoff(f func(func() error) error) ECSClientOption { + return func(client *ecsClient) { + client.sascCustomRetryBackoff = f + } +} + +// WithSTSCAttachmentCustomRetryBackoff is an ECSClientOption that configures the +// ecsClient.stscAttachmentCustomRetryBackoff with the value passed as a parameter. +func WithSTSCAttachmentCustomRetryBackoff(f func(func() error) error) ECSClientOption { + return func(client *ecsClient) { + client.stscAttachmentCustomRetryBackoff = f + } +} + +// WithDiscoverPollEndpointCache is an ECSClientOption that configures the +// ecsClient.pollEndpointCache with the value passed as a parameter. +// This is especially useful for injecting a test implementation. +func WithDiscoverPollEndpointCache(c async.TTLCache) ECSClientOption { + return func(client *ecsClient) { + client.pollEndpointCache = c + } +} + +// WithStandardClient is an ECSClientOption that configures the +// ecsClient.standardClient with the value passed as a parameter. +// This is especially useful for injecting a test implementation. +func WithStandardClient(s ecs.ECSStandardSDK) ECSClientOption { + return func(client *ecsClient) { + client.standardClient = s + } +} + +// WithSubmitStateChangeClient is an ECSClientOption that configures the +// ecsClient.submitStateChangeClient with the value passed as a parameter. +// This is especially useful for injecting a test implementation. +func WithSubmitStateChangeClient(s ecs.ECSSubmitStateSDK) ECSClientOption { + return func(client *ecsClient) { + client.submitStateChangeClient = s + } +} diff --git a/agent/api/ecsclient/retry_handler.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/retry_handler.go similarity index 100% rename from agent/api/ecsclient/retry_handler.go rename to agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/retry_handler.go diff --git a/agent/api/ecsclient/utils.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/utils.go similarity index 100% rename from agent/api/ecsclient/utils.go rename to agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/utils.go diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/generate_mocks.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/generate_mocks.go deleted file mode 100644 index 57e12a724fc..00000000000 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/generate_mocks.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -//go:generate mockgen -destination=mocks/api_mocks.go -copyright_file=../../scripts/copyright_file github.com/aws/amazon-ecs-agent/ecs-agent/api ECSDiscoverEndpointSDK - -package api diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/interface.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/interface.go deleted file mode 100644 index 22324f0a367..00000000000 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/interface.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. -package api - -// ECSDiscoverEndpointSDK is an interface with customized ecs client that -// implements the DiscoverPollEndpoint, DiscoverTelemetryEndpoint, and DiscoverServiceConnectEndpoint -type ECSDiscoverEndpointSDK interface { - DiscoverPollEndpoint(containerInstanceArn string) (string, error) - DiscoverTelemetryEndpoint(containerInstanceArn string) (string, error) - DiscoverServiceConnectEndpoint(containerInstanceArn string) (string, error) -} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/utils.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/utils.go index fa4581a5d30..7b08a499af2 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/utils.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/utils/utils.go @@ -15,6 +15,8 @@ package utils import ( "reflect" "strconv" + + "github.com/aws/aws-sdk-go/aws" ) func ZeroOrNil(obj interface{}) bool { @@ -49,3 +51,11 @@ func Uint16SliceToStringSlice(slice []uint16) []*string { } return stringSlice } + +// Int64PtrToIntPtr converts a *int64 to *int. +func Int64PtrToIntPtr(int64ptr *int64) *int { + if int64ptr == nil { + return nil + } + return aws.Int(int(aws.Int64Value(int64ptr))) +} diff --git a/agent/vendor/github.com/docker/docker/pkg/process/doc.go b/agent/vendor/github.com/docker/docker/pkg/process/doc.go deleted file mode 100644 index dae536d7dbb..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/process/doc.go +++ /dev/null @@ -1,3 +0,0 @@ -// Package process provides a set of basic functions to manage individual -// processes. -package process diff --git a/agent/vendor/github.com/docker/docker/pkg/process/process_unix.go b/agent/vendor/github.com/docker/docker/pkg/process/process_unix.go deleted file mode 100644 index daf39236269..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/process/process_unix.go +++ /dev/null @@ -1,82 +0,0 @@ -//go:build !windows -// +build !windows - -package process - -import ( - "bytes" - "fmt" - "os" - "path/filepath" - "runtime" - "strconv" - - "golang.org/x/sys/unix" -) - -// Alive returns true if process with a given pid is running. It only considers -// positive PIDs; 0 (all processes in the current process group), -1 (all processes -// with a PID larger than 1), and negative (-n, all processes in process group -// "n") values for pid are never considered to be alive. -func Alive(pid int) bool { - if pid < 1 { - return false - } - switch runtime.GOOS { - case "darwin": - // OS X does not have a proc filesystem. Use kill -0 pid to judge if the - // process exists. From KILL(2): https://www.freebsd.org/cgi/man.cgi?query=kill&sektion=2&manpath=OpenDarwin+7.2.1 - // - // Sig may be one of the signals specified in sigaction(2) or it may - // be 0, in which case error checking is performed but no signal is - // actually sent. This can be used to check the validity of pid. - err := unix.Kill(pid, 0) - - // Either the PID was found (no error) or we get an EPERM, which means - // the PID exists, but we don't have permissions to signal it. - return err == nil || err == unix.EPERM - default: - _, err := os.Stat(filepath.Join("/proc", strconv.Itoa(pid))) - return err == nil - } -} - -// Kill force-stops a process. It only considers positive PIDs; 0 (all processes -// in the current process group), -1 (all processes with a PID larger than 1), -// and negative (-n, all processes in process group "n") values for pid are -// ignored. Refer to [KILL(2)] for details. -// -// [KILL(2)]: https://man7.org/linux/man-pages/man2/kill.2.html -func Kill(pid int) error { - if pid < 1 { - return fmt.Errorf("invalid PID (%d): only positive PIDs are allowed", pid) - } - err := unix.Kill(pid, unix.SIGKILL) - if err != nil && err != unix.ESRCH { - return err - } - return nil -} - -// Zombie return true if process has a state with "Z". It only considers positive -// PIDs; 0 (all processes in the current process group), -1 (all processes with -// a PID larger than 1), and negative (-n, all processes in process group "n") -// values for pid are ignored. Refer to [PROC(5)] for details. -// -// [PROC(5)]: https://man7.org/linux/man-pages/man5/proc.5.html -func Zombie(pid int) (bool, error) { - if pid < 1 { - return false, nil - } - data, err := os.ReadFile(fmt.Sprintf("/proc/%d/stat", pid)) - if err != nil { - if os.IsNotExist(err) { - return false, nil - } - return false, err - } - if cols := bytes.SplitN(data, []byte(" "), 4); len(cols) >= 3 && string(cols[2]) == "Z" { - return true, nil - } - return false, nil -} diff --git a/agent/vendor/github.com/docker/docker/pkg/process/process_windows.go b/agent/vendor/github.com/docker/docker/pkg/process/process_windows.go deleted file mode 100644 index 26158d09ece..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/process/process_windows.go +++ /dev/null @@ -1,52 +0,0 @@ -package process - -import ( - "os" - - "golang.org/x/sys/windows" -) - -// Alive returns true if process with a given pid is running. -func Alive(pid int) bool { - h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, uint32(pid)) - if err != nil { - return false - } - var c uint32 - err = windows.GetExitCodeProcess(h, &c) - _ = windows.CloseHandle(h) - if err != nil { - // From the GetExitCodeProcess function (processthreadsapi.h) API docs: - // https://learn.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-getexitcodeprocess - // - // The GetExitCodeProcess function returns a valid error code defined by the - // application only after the thread terminates. Therefore, an application should - // not use STILL_ACTIVE (259) as an error code (STILL_ACTIVE is a macro for - // STATUS_PENDING (minwinbase.h)). If a thread returns STILL_ACTIVE (259) as - // an error code, then applications that test for that value could interpret it - // to mean that the thread is still running, and continue to test for the - // completion of the thread after the thread has terminated, which could put - // the application into an infinite loop. - return c == uint32(windows.STATUS_PENDING) - } - return true -} - -// Kill force-stops a process. -func Kill(pid int) error { - p, err := os.FindProcess(pid) - if err == nil { - err = p.Kill() - if err != nil && err != os.ErrProcessDone { - return err - } - } - return nil -} - -// Zombie is not supported on Windows. -// -// TODO(thaJeztah): remove once we remove the stubs from pkg/system. -func Zombie(_ int) (bool, error) { - return false, nil -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/args_windows.go b/agent/vendor/github.com/docker/docker/pkg/system/args_windows.go deleted file mode 100644 index b7c9487a067..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/args_windows.go +++ /dev/null @@ -1,16 +0,0 @@ -package system // import "github.com/docker/docker/pkg/system" - -import ( - "strings" - - "golang.org/x/sys/windows" -) - -// EscapeArgs makes a Windows-style escaped command line from a set of arguments -func EscapeArgs(args []string) string { - escapedArgs := make([]string, len(args)) - for i, a := range args { - escapedArgs[i] = windows.EscapeArg(a) - } - return strings.Join(escapedArgs, " ") -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/chtimes.go b/agent/vendor/github.com/docker/docker/pkg/system/chtimes.go deleted file mode 100644 index 6a6bca43eda..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/chtimes.go +++ /dev/null @@ -1,48 +0,0 @@ -package system // import "github.com/docker/docker/pkg/system" - -import ( - "os" - "syscall" - "time" - "unsafe" -) - -// Used by Chtimes -var unixEpochTime, unixMaxTime time.Time - -func init() { - unixEpochTime = time.Unix(0, 0) - if unsafe.Sizeof(syscall.Timespec{}.Nsec) == 8 { - // This is a 64 bit timespec - // os.Chtimes limits time to the following - // - // Note that this intentionally sets nsec (not sec), which sets both sec - // and nsec internally in time.Unix(); - // https://github.com/golang/go/blob/go1.19.2/src/time/time.go#L1364-L1380 - unixMaxTime = time.Unix(0, 1<<63-1) - } else { - // This is a 32 bit timespec - unixMaxTime = time.Unix(1<<31-1, 0) - } -} - -// Chtimes changes the access time and modified time of a file at the given path. -// If the modified time is prior to the Unix Epoch (unixMinTime), or after the -// end of Unix Time (unixEpochTime), os.Chtimes has undefined behavior. In this -// case, Chtimes defaults to Unix Epoch, just in case. -func Chtimes(name string, atime time.Time, mtime time.Time) error { - if atime.Before(unixEpochTime) || atime.After(unixMaxTime) { - atime = unixEpochTime - } - - if mtime.Before(unixEpochTime) || mtime.After(unixMaxTime) { - mtime = unixEpochTime - } - - if err := os.Chtimes(name, atime, mtime); err != nil { - return err - } - - // Take platform specific action for setting create time. - return setCTime(name, mtime) -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/chtimes_nowindows.go b/agent/vendor/github.com/docker/docker/pkg/system/chtimes_nowindows.go deleted file mode 100644 index 84ae1570513..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/chtimes_nowindows.go +++ /dev/null @@ -1,15 +0,0 @@ -//go:build !windows -// +build !windows - -package system // import "github.com/docker/docker/pkg/system" - -import ( - "time" -) - -// setCTime will set the create time on a file. On Unix, the create -// time is updated as a side effect of setting the modified time, so -// no action is required. -func setCTime(path string, ctime time.Time) error { - return nil -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/chtimes_windows.go b/agent/vendor/github.com/docker/docker/pkg/system/chtimes_windows.go deleted file mode 100644 index ab478f5c38e..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/chtimes_windows.go +++ /dev/null @@ -1,25 +0,0 @@ -package system // import "github.com/docker/docker/pkg/system" - -import ( - "time" - - "golang.org/x/sys/windows" -) - -// setCTime will set the create time on a file. On Windows, this requires -// calling SetFileTime and explicitly including the create time. -func setCTime(path string, ctime time.Time) error { - pathp, err := windows.UTF16PtrFromString(path) - if err != nil { - return err - } - h, err := windows.CreateFile(pathp, - windows.FILE_WRITE_ATTRIBUTES, windows.FILE_SHARE_WRITE, nil, - windows.OPEN_EXISTING, windows.FILE_FLAG_BACKUP_SEMANTICS, 0) - if err != nil { - return err - } - defer windows.Close(h) - c := windows.NsecToFiletime(ctime.UnixNano()) - return windows.SetFileTime(h, &c, nil, nil) -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/errors.go b/agent/vendor/github.com/docker/docker/pkg/system/errors.go deleted file mode 100644 index 2573d716222..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/errors.go +++ /dev/null @@ -1,13 +0,0 @@ -package system // import "github.com/docker/docker/pkg/system" - -import ( - "errors" -) - -var ( - // ErrNotSupportedPlatform means the platform is not supported. - ErrNotSupportedPlatform = errors.New("platform and architecture is not supported") - - // ErrNotSupportedOperatingSystem means the operating system is not supported. - ErrNotSupportedOperatingSystem = errors.New("operating system is not supported") -) diff --git a/agent/vendor/github.com/docker/docker/pkg/system/filesys.go b/agent/vendor/github.com/docker/docker/pkg/system/filesys.go deleted file mode 100644 index ce5990c914f..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/filesys.go +++ /dev/null @@ -1,19 +0,0 @@ -package system - -import ( - "os" - "path/filepath" - "strings" -) - -// IsAbs is a platform-agnostic wrapper for filepath.IsAbs. -// -// On Windows, golang filepath.IsAbs does not consider a path \windows\system32 -// as absolute as it doesn't start with a drive-letter/colon combination. However, -// in docker we need to verify things such as WORKDIR /windows/system32 in -// a Dockerfile (which gets translated to \windows\system32 when being processed -// by the daemon). This SHOULD be treated as absolute from a docker processing -// perspective. -func IsAbs(path string) bool { - return filepath.IsAbs(path) || strings.HasPrefix(path, string(os.PathSeparator)) -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/filesys_unix.go b/agent/vendor/github.com/docker/docker/pkg/system/filesys_unix.go deleted file mode 100644 index 38011294049..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/filesys_unix.go +++ /dev/null @@ -1,17 +0,0 @@ -//go:build !windows -// +build !windows - -package system // import "github.com/docker/docker/pkg/system" - -import "os" - -// MkdirAllWithACL is a wrapper for os.MkdirAll on unix systems. -func MkdirAllWithACL(path string, perm os.FileMode, sddl string) error { - return os.MkdirAll(path, perm) -} - -// MkdirAll creates a directory named path along with any necessary parents, -// with permission specified by attribute perm for all dir created. -func MkdirAll(path string, perm os.FileMode) error { - return os.MkdirAll(path, perm) -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/filesys_windows.go b/agent/vendor/github.com/docker/docker/pkg/system/filesys_windows.go deleted file mode 100644 index 92e972ea2e3..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/filesys_windows.go +++ /dev/null @@ -1,135 +0,0 @@ -package system // import "github.com/docker/docker/pkg/system" - -import ( - "os" - "regexp" - "syscall" - "unsafe" - - "golang.org/x/sys/windows" -) - -// SddlAdministratorsLocalSystem is local administrators plus NT AUTHORITY\System. -const SddlAdministratorsLocalSystem = "D:P(A;OICI;GA;;;BA)(A;OICI;GA;;;SY)" - -// volumePath is a regular expression to check if a path is a Windows -// volume path (e.g., "\\?\Volume{4c1b02c1-d990-11dc-99ae-806e6f6e6963}" -// or "\\?\Volume{4c1b02c1-d990-11dc-99ae-806e6f6e6963}\"). -var volumePath = regexp.MustCompile(`^\\\\\?\\Volume{[a-z0-9-]+}\\?$`) - -// MkdirAllWithACL is a custom version of os.MkdirAll modified for use on Windows -// so that it is both volume path aware, and can create a directory with -// an appropriate SDDL defined ACL. -func MkdirAllWithACL(path string, _ os.FileMode, sddl string) error { - sa, err := makeSecurityAttributes(sddl) - if err != nil { - return &os.PathError{Op: "mkdirall", Path: path, Err: err} - } - return mkdirall(path, sa) -} - -// MkdirAll is a custom version of os.MkdirAll that is volume path aware for -// Windows. It can be used as a drop-in replacement for os.MkdirAll. -func MkdirAll(path string, _ os.FileMode) error { - return mkdirall(path, nil) -} - -// mkdirall is a custom version of os.MkdirAll modified for use on Windows -// so that it is both volume path aware, and can create a directory with -// a DACL. -func mkdirall(path string, perm *windows.SecurityAttributes) error { - if volumePath.MatchString(path) { - return nil - } - - // The rest of this method is largely copied from os.MkdirAll and should be kept - // as-is to ensure compatibility. - - // Fast path: if we can tell whether path is a directory or file, stop with success or error. - dir, err := os.Stat(path) - if err == nil { - if dir.IsDir() { - return nil - } - return &os.PathError{Op: "mkdir", Path: path, Err: syscall.ENOTDIR} - } - - // Slow path: make sure parent exists and then call Mkdir for path. - i := len(path) - for i > 0 && os.IsPathSeparator(path[i-1]) { // Skip trailing path separator. - i-- - } - - j := i - for j > 0 && !os.IsPathSeparator(path[j-1]) { // Scan backward over element. - j-- - } - - if j > 1 { - // Create parent. - err = mkdirall(fixRootDirectory(path[:j-1]), perm) - if err != nil { - return err - } - } - - // Parent now exists; invoke Mkdir and use its result. - err = mkdirWithACL(path, perm) - if err != nil { - // Handle arguments like "foo/." by - // double-checking that directory doesn't exist. - dir, err1 := os.Lstat(path) - if err1 == nil && dir.IsDir() { - return nil - } - return err - } - return nil -} - -// mkdirWithACL creates a new directory. If there is an error, it will be of -// type *PathError. . -// -// This is a modified and combined version of os.Mkdir and windows.Mkdir -// in golang to cater for creating a directory am ACL permitting full -// access, with inheritance, to any subfolder/file for Built-in Administrators -// and Local System. -func mkdirWithACL(name string, sa *windows.SecurityAttributes) error { - if sa == nil { - return os.Mkdir(name, 0) - } - - namep, err := windows.UTF16PtrFromString(name) - if err != nil { - return &os.PathError{Op: "mkdir", Path: name, Err: err} - } - - err = windows.CreateDirectory(namep, sa) - if err != nil { - return &os.PathError{Op: "mkdir", Path: name, Err: err} - } - return nil -} - -// fixRootDirectory fixes a reference to a drive's root directory to -// have the required trailing slash. -func fixRootDirectory(p string) string { - if len(p) == len(`\\?\c:`) { - if os.IsPathSeparator(p[0]) && os.IsPathSeparator(p[1]) && p[2] == '?' && os.IsPathSeparator(p[3]) && p[5] == ':' { - return p + `\` - } - } - return p -} - -func makeSecurityAttributes(sddl string) (*windows.SecurityAttributes, error) { - var sa windows.SecurityAttributes - sa.Length = uint32(unsafe.Sizeof(sa)) - sa.InheritHandle = 1 - var err error - sa.SecurityDescriptor, err = windows.SecurityDescriptorFromString(sddl) - if err != nil { - return nil, err - } - return &sa, nil -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/image_os.go b/agent/vendor/github.com/docker/docker/pkg/system/image_os.go deleted file mode 100644 index e3de86be292..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/image_os.go +++ /dev/null @@ -1,10 +0,0 @@ -package system // import "github.com/docker/docker/pkg/system" -import ( - "runtime" - "strings" -) - -// IsOSSupported determines if an operating system is supported by the host. -func IsOSSupported(os string) bool { - return strings.EqualFold(runtime.GOOS, os) -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/init_windows.go b/agent/vendor/github.com/docker/docker/pkg/system/init_windows.go deleted file mode 100644 index 3c2a43ddbd3..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/init_windows.go +++ /dev/null @@ -1,18 +0,0 @@ -package system // import "github.com/docker/docker/pkg/system" - -var ( - // containerdRuntimeSupported determines if containerd should be the runtime. - containerdRuntimeSupported = false -) - -// InitContainerdRuntime sets whether to use containerd for runtime on Windows. -func InitContainerdRuntime(cdPath string) { - if len(cdPath) > 0 { - containerdRuntimeSupported = true - } -} - -// ContainerdRuntimeSupported returns true if the use of containerd runtime is supported. -func ContainerdRuntimeSupported() bool { - return containerdRuntimeSupported -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/lstat_unix.go b/agent/vendor/github.com/docker/docker/pkg/system/lstat_unix.go deleted file mode 100644 index 654b9f2c9e6..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/lstat_unix.go +++ /dev/null @@ -1,21 +0,0 @@ -//go:build !windows -// +build !windows - -package system // import "github.com/docker/docker/pkg/system" - -import ( - "os" - "syscall" -) - -// Lstat takes a path to a file and returns -// a system.StatT type pertaining to that file. -// -// Throws an error if the file does not exist -func Lstat(path string) (*StatT, error) { - s := &syscall.Stat_t{} - if err := syscall.Lstat(path, s); err != nil { - return nil, &os.PathError{Op: "Lstat", Path: path, Err: err} - } - return fromStatT(s) -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/lstat_windows.go b/agent/vendor/github.com/docker/docker/pkg/system/lstat_windows.go deleted file mode 100644 index 359c791d9b6..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/lstat_windows.go +++ /dev/null @@ -1,14 +0,0 @@ -package system // import "github.com/docker/docker/pkg/system" - -import "os" - -// Lstat calls os.Lstat to get a fileinfo interface back. -// This is then copied into our own locally defined structure. -func Lstat(path string) (*StatT, error) { - fi, err := os.Lstat(path) - if err != nil { - return nil, err - } - - return fromStatT(&fi) -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/meminfo_deprecated.go b/agent/vendor/github.com/docker/docker/pkg/system/meminfo_deprecated.go deleted file mode 100644 index 216519923e0..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/meminfo_deprecated.go +++ /dev/null @@ -1,16 +0,0 @@ -package system - -import "github.com/docker/docker/pkg/meminfo" - -// MemInfo contains memory statistics of the host system. -// -// Deprecated: use [meminfo.Memory]. -type MemInfo = meminfo.Memory - -// ReadMemInfo retrieves memory statistics of the host system and returns a -// MemInfo type. -// -// Deprecated: use [meminfo.Read]. -func ReadMemInfo() (*meminfo.Memory, error) { - return meminfo.Read() -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/mknod.go b/agent/vendor/github.com/docker/docker/pkg/system/mknod.go deleted file mode 100644 index d27152c0f5b..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/mknod.go +++ /dev/null @@ -1,17 +0,0 @@ -//go:build !windows -// +build !windows - -package system // import "github.com/docker/docker/pkg/system" - -import ( - "golang.org/x/sys/unix" -) - -// Mkdev is used to build the value of linux devices (in /dev/) which specifies major -// and minor number of the newly created device special file. -// Linux device nodes are a bit weird due to backwards compat with 16 bit device nodes. -// They are, from low to high: the lower 8 bits of the minor, then 12 bits of the major, -// then the top 12 bits of the minor. -func Mkdev(major int64, minor int64) uint32 { - return uint32(unix.Mkdev(uint32(major), uint32(minor))) -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/mknod_freebsd.go b/agent/vendor/github.com/docker/docker/pkg/system/mknod_freebsd.go deleted file mode 100644 index c890be116f7..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/mknod_freebsd.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build freebsd -// +build freebsd - -package system // import "github.com/docker/docker/pkg/system" - -import ( - "golang.org/x/sys/unix" -) - -// Mknod creates a filesystem node (file, device special file or named pipe) named path -// with attributes specified by mode and dev. -func Mknod(path string, mode uint32, dev int) error { - return unix.Mknod(path, mode, uint64(dev)) -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/mknod_unix.go b/agent/vendor/github.com/docker/docker/pkg/system/mknod_unix.go deleted file mode 100644 index 4586aad19e6..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/mknod_unix.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build !freebsd && !windows -// +build !freebsd,!windows - -package system // import "github.com/docker/docker/pkg/system" - -import ( - "golang.org/x/sys/unix" -) - -// Mknod creates a filesystem node (file, device special file or named pipe) named path -// with attributes specified by mode and dev. -func Mknod(path string, mode uint32, dev int) error { - return unix.Mknod(path, mode, dev) -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/mknod_windows.go b/agent/vendor/github.com/docker/docker/pkg/system/mknod_windows.go deleted file mode 100644 index ec89d7a15ea..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/mknod_windows.go +++ /dev/null @@ -1,11 +0,0 @@ -package system // import "github.com/docker/docker/pkg/system" - -// Mknod is not implemented on Windows. -func Mknod(path string, mode uint32, dev int) error { - return ErrNotSupportedPlatform -} - -// Mkdev is not implemented on Windows. -func Mkdev(major int64, minor int64) uint32 { - panic("Mkdev not implemented on Windows.") -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/path_deprecated.go b/agent/vendor/github.com/docker/docker/pkg/system/path_deprecated.go deleted file mode 100644 index 5c95026c3d1..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/path_deprecated.go +++ /dev/null @@ -1,18 +0,0 @@ -package system - -const defaultUnixPathEnv = "/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" - -// DefaultPathEnv is unix style list of directories to search for -// executables. Each directory is separated from the next by a colon -// ':' character . -// For Windows containers, an empty string is returned as the default -// path will be set by the container, and Docker has no context of what the -// default path should be. -// -// Deprecated: use oci.DefaultPathEnv -func DefaultPathEnv(os string) string { - if os == "windows" { - return "" - } - return defaultUnixPathEnv -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/process_deprecated.go b/agent/vendor/github.com/docker/docker/pkg/system/process_deprecated.go deleted file mode 100644 index 7b9f19acd5f..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/process_deprecated.go +++ /dev/null @@ -1,27 +0,0 @@ -//go:build linux || freebsd || darwin || windows -// +build linux freebsd darwin windows - -package system - -import "github.com/docker/docker/pkg/process" - -var ( - // IsProcessAlive returns true if process with a given pid is running. - // - // Deprecated: use [process.Alive]. - IsProcessAlive = process.Alive - - // IsProcessZombie return true if process has a state with "Z" - // - // Deprecated: use [process.Zombie]. - // - // TODO(thaJeztah): remove the Windows implementation in process once we remove this stub. - IsProcessZombie = process.Zombie -) - -// KillProcess force-stops a process. -// -// Deprecated: use [process.Kill]. -func KillProcess(pid int) { - _ = process.Kill(pid) -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/stat_bsd.go b/agent/vendor/github.com/docker/docker/pkg/system/stat_bsd.go deleted file mode 100644 index 8e61d820f02..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/stat_bsd.go +++ /dev/null @@ -1,16 +0,0 @@ -//go:build freebsd || netbsd -// +build freebsd netbsd - -package system // import "github.com/docker/docker/pkg/system" - -import "syscall" - -// fromStatT converts a syscall.Stat_t type to a system.Stat_t type -func fromStatT(s *syscall.Stat_t) (*StatT, error) { - return &StatT{size: s.Size, - mode: uint32(s.Mode), - uid: s.Uid, - gid: s.Gid, - rdev: uint64(s.Rdev), - mtim: s.Mtimespec}, nil -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/stat_darwin.go b/agent/vendor/github.com/docker/docker/pkg/system/stat_darwin.go deleted file mode 100644 index c1c0ee9f386..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/stat_darwin.go +++ /dev/null @@ -1,13 +0,0 @@ -package system // import "github.com/docker/docker/pkg/system" - -import "syscall" - -// fromStatT converts a syscall.Stat_t type to a system.Stat_t type -func fromStatT(s *syscall.Stat_t) (*StatT, error) { - return &StatT{size: s.Size, - mode: uint32(s.Mode), - uid: s.Uid, - gid: s.Gid, - rdev: uint64(s.Rdev), - mtim: s.Mtimespec}, nil -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/stat_linux.go b/agent/vendor/github.com/docker/docker/pkg/system/stat_linux.go deleted file mode 100644 index 3ac02393f0a..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/stat_linux.go +++ /dev/null @@ -1,20 +0,0 @@ -package system // import "github.com/docker/docker/pkg/system" - -import "syscall" - -// fromStatT converts a syscall.Stat_t type to a system.Stat_t type -func fromStatT(s *syscall.Stat_t) (*StatT, error) { - return &StatT{size: s.Size, - mode: s.Mode, - uid: s.Uid, - gid: s.Gid, - // the type is 32bit on mips - rdev: uint64(s.Rdev), //nolint: unconvert - mtim: s.Mtim}, nil -} - -// FromStatT converts a syscall.Stat_t type to a system.Stat_t type -// This is exposed on Linux as pkg/archive/changes uses it. -func FromStatT(s *syscall.Stat_t) (*StatT, error) { - return fromStatT(s) -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/stat_openbsd.go b/agent/vendor/github.com/docker/docker/pkg/system/stat_openbsd.go deleted file mode 100644 index 756b92d1e6c..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/stat_openbsd.go +++ /dev/null @@ -1,13 +0,0 @@ -package system // import "github.com/docker/docker/pkg/system" - -import "syscall" - -// fromStatT converts a syscall.Stat_t type to a system.Stat_t type -func fromStatT(s *syscall.Stat_t) (*StatT, error) { - return &StatT{size: s.Size, - mode: uint32(s.Mode), - uid: s.Uid, - gid: s.Gid, - rdev: uint64(s.Rdev), - mtim: s.Mtim}, nil -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/stat_unix.go b/agent/vendor/github.com/docker/docker/pkg/system/stat_unix.go deleted file mode 100644 index a45ffddf750..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/stat_unix.go +++ /dev/null @@ -1,67 +0,0 @@ -//go:build !windows -// +build !windows - -package system // import "github.com/docker/docker/pkg/system" - -import ( - "os" - "syscall" -) - -// StatT type contains status of a file. It contains metadata -// like permission, owner, group, size, etc about a file. -type StatT struct { - mode uint32 - uid uint32 - gid uint32 - rdev uint64 - size int64 - mtim syscall.Timespec -} - -// Mode returns file's permission mode. -func (s StatT) Mode() uint32 { - return s.mode -} - -// UID returns file's user id of owner. -func (s StatT) UID() uint32 { - return s.uid -} - -// GID returns file's group id of owner. -func (s StatT) GID() uint32 { - return s.gid -} - -// Rdev returns file's device ID (if it's special file). -func (s StatT) Rdev() uint64 { - return s.rdev -} - -// Size returns file's size. -func (s StatT) Size() int64 { - return s.size -} - -// Mtim returns file's last modification time. -func (s StatT) Mtim() syscall.Timespec { - return s.mtim -} - -// IsDir reports whether s describes a directory. -func (s StatT) IsDir() bool { - return s.mode&syscall.S_IFDIR != 0 -} - -// Stat takes a path to a file and returns -// a system.StatT type pertaining to that file. -// -// Throws an error if the file does not exist -func Stat(path string) (*StatT, error) { - s := &syscall.Stat_t{} - if err := syscall.Stat(path, s); err != nil { - return nil, &os.PathError{Op: "Stat", Path: path, Err: err} - } - return fromStatT(s) -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/stat_windows.go b/agent/vendor/github.com/docker/docker/pkg/system/stat_windows.go deleted file mode 100644 index 0ff3af2fa17..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/stat_windows.go +++ /dev/null @@ -1,49 +0,0 @@ -package system // import "github.com/docker/docker/pkg/system" - -import ( - "os" - "time" -) - -// StatT type contains status of a file. It contains metadata -// like permission, size, etc about a file. -type StatT struct { - mode os.FileMode - size int64 - mtim time.Time -} - -// Size returns file's size. -func (s StatT) Size() int64 { - return s.size -} - -// Mode returns file's permission mode. -func (s StatT) Mode() os.FileMode { - return s.mode -} - -// Mtim returns file's last modification time. -func (s StatT) Mtim() time.Time { - return s.mtim -} - -// Stat takes a path to a file and returns -// a system.StatT type pertaining to that file. -// -// Throws an error if the file does not exist -func Stat(path string) (*StatT, error) { - fi, err := os.Stat(path) - if err != nil { - return nil, err - } - return fromStatT(&fi) -} - -// fromStatT converts a os.FileInfo type to a system.StatT type -func fromStatT(fi *os.FileInfo) (*StatT, error) { - return &StatT{ - size: (*fi).Size(), - mode: (*fi).Mode(), - mtim: (*fi).ModTime()}, nil -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/utimes_unix.go b/agent/vendor/github.com/docker/docker/pkg/system/utimes_unix.go deleted file mode 100644 index 2768750a00b..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/utimes_unix.go +++ /dev/null @@ -1,25 +0,0 @@ -//go:build linux || freebsd -// +build linux freebsd - -package system // import "github.com/docker/docker/pkg/system" - -import ( - "syscall" - - "golang.org/x/sys/unix" -) - -// LUtimesNano is used to change access and modification time of the specified path. -// It's used for symbol link file because unix.UtimesNano doesn't support a NOFOLLOW flag atm. -func LUtimesNano(path string, ts []syscall.Timespec) error { - uts := []unix.Timespec{ - unix.NsecToTimespec(syscall.TimespecToNsec(ts[0])), - unix.NsecToTimespec(syscall.TimespecToNsec(ts[1])), - } - err := unix.UtimesNanoAt(unix.AT_FDCWD, path, uts, unix.AT_SYMLINK_NOFOLLOW) - if err != nil && err != unix.ENOSYS { - return err - } - - return nil -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/utimes_unsupported.go b/agent/vendor/github.com/docker/docker/pkg/system/utimes_unsupported.go deleted file mode 100644 index bfed4af0325..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/utimes_unsupported.go +++ /dev/null @@ -1,11 +0,0 @@ -//go:build !linux && !freebsd -// +build !linux,!freebsd - -package system // import "github.com/docker/docker/pkg/system" - -import "syscall" - -// LUtimesNano is only supported on linux and freebsd. -func LUtimesNano(path string, ts []syscall.Timespec) error { - return ErrNotSupportedPlatform -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/xattrs_linux.go b/agent/vendor/github.com/docker/docker/pkg/system/xattrs_linux.go deleted file mode 100644 index 95b609fe7a8..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/xattrs_linux.go +++ /dev/null @@ -1,37 +0,0 @@ -package system // import "github.com/docker/docker/pkg/system" - -import "golang.org/x/sys/unix" - -// Lgetxattr retrieves the value of the extended attribute identified by attr -// and associated with the given path in the file system. -// It will returns a nil slice and nil error if the xattr is not set. -func Lgetxattr(path string, attr string) ([]byte, error) { - // Start with a 128 length byte array - dest := make([]byte, 128) - sz, errno := unix.Lgetxattr(path, attr, dest) - - for errno == unix.ERANGE { - // Buffer too small, use zero-sized buffer to get the actual size - sz, errno = unix.Lgetxattr(path, attr, []byte{}) - if errno != nil { - return nil, errno - } - dest = make([]byte, sz) - sz, errno = unix.Lgetxattr(path, attr, dest) - } - - switch { - case errno == unix.ENODATA: - return nil, nil - case errno != nil: - return nil, errno - } - - return dest[:sz], nil -} - -// Lsetxattr sets the value of the extended attribute identified by attr -// and associated with the given path in the file system. -func Lsetxattr(path string, attr string, data []byte, flags int) error { - return unix.Lsetxattr(path, attr, data, flags) -} diff --git a/agent/vendor/github.com/docker/docker/pkg/system/xattrs_unsupported.go b/agent/vendor/github.com/docker/docker/pkg/system/xattrs_unsupported.go deleted file mode 100644 index b165a5dbfe9..00000000000 --- a/agent/vendor/github.com/docker/docker/pkg/system/xattrs_unsupported.go +++ /dev/null @@ -1,14 +0,0 @@ -//go:build !linux -// +build !linux - -package system // import "github.com/docker/docker/pkg/system" - -// Lgetxattr is not supported on platforms other than linux. -func Lgetxattr(path string, attr string) ([]byte, error) { - return nil, ErrNotSupportedPlatform -} - -// Lsetxattr is not supported on platforms other than linux. -func Lsetxattr(path string, attr string, data []byte, flags int) error { - return ErrNotSupportedPlatform -} diff --git a/agent/vendor/modules.txt b/agent/vendor/modules.txt index b4a87f43887..d51603b72cd 100644 --- a/agent/vendor/modules.txt +++ b/agent/vendor/modules.txt @@ -9,7 +9,6 @@ github.com/aws/amazon-ecs-agent/ecs-agent/acs/client github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs github.com/aws/amazon-ecs-agent/ecs-agent/acs/session github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst -github.com/aws/amazon-ecs-agent/ecs-agent/api github.com/aws/amazon-ecs-agent/ecs-agent/api/appnet github.com/aws/amazon-ecs-agent/ecs-agent/api/appnet/mocks github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment @@ -17,6 +16,7 @@ github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment/resource github.com/aws/amazon-ecs-agent/ecs-agent/api/attachment/resource/mocks github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs +github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs github.com/aws/amazon-ecs-agent/ecs-agent/api/errors @@ -233,9 +233,7 @@ github.com/docker/docker/pkg/longpath github.com/docker/docker/pkg/meminfo github.com/docker/docker/pkg/plugins github.com/docker/docker/pkg/plugins/transport -github.com/docker/docker/pkg/process github.com/docker/docker/pkg/rootless -github.com/docker/docker/pkg/system # github.com/docker/go-connections v0.4.0 ## explicit github.com/docker/go-connections/nat diff --git a/ecs-agent/acs/session/session.go b/ecs-agent/acs/session/session.go index cfd8299bc2c..7cc815321ca 100644 --- a/ecs-agent/acs/session/session.go +++ b/ecs-agent/acs/session/session.go @@ -23,7 +23,7 @@ import ( "strings" "time" - "github.com/aws/amazon-ecs-agent/ecs-agent/api" + "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs" rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" @@ -72,7 +72,7 @@ type session struct { containerInstanceARN string cluster string credentialsProvider *credentials.Credentials - discoverEndpointClient api.ECSDiscoverEndpointSDK + ecsClient ecs.ECSClient inactiveInstanceCB func() agentVersion string agentHash string @@ -104,7 +104,7 @@ type session struct { // NewSession creates a new Session. func NewSession(containerInstanceARN string, cluster string, - discoverEndpointClient api.ECSDiscoverEndpointSDK, + ecsClient ecs.ECSClient, credentialsProvider *credentials.Credentials, inactiveInstanceCB func(), clientFactory wsclient.ClientFactory, @@ -130,7 +130,7 @@ func NewSession(containerInstanceARN string, return &session{ containerInstanceARN: containerInstanceARN, cluster: cluster, - discoverEndpointClient: discoverEndpointClient, + ecsClient: ecsClient, credentialsProvider: credentialsProvider, inactiveInstanceCB: inactiveInstanceCB, clientFactory: clientFactory, @@ -220,7 +220,7 @@ func (s *session) Start(ctx context.Context) error { // startSessionOnce creates a session with ACS and handles requests using the passed // in arguments. func (s *session) startSessionOnce(ctx context.Context) error { - acsEndpoint, err := s.discoverEndpointClient.DiscoverPollEndpoint(s.containerInstanceARN) + acsEndpoint, err := s.ecsClient.DiscoverPollEndpoint(s.containerInstanceARN) if err != nil { logger.Error("ACS: Unable to discover poll endpoint", logger.Fields{ field.Error: err, diff --git a/ecs-agent/acs/session/session_test.go b/ecs-agent/acs/session/session_test.go index 77c8cece77b..170cfccd898 100644 --- a/ecs-agent/acs/session/session_test.go +++ b/ecs-agent/acs/session/session_test.go @@ -33,7 +33,7 @@ import ( acsclient "github.com/aws/amazon-ecs-agent/ecs-agent/acs/client" mock_session "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst" - mock_api "github.com/aws/amazon-ecs-agent/ecs-agent/api/mocks" + mock_ecs "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/mocks" rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" mock_credentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/doctor" @@ -224,8 +224,8 @@ func TestSessionReconnectsOnConnectErrors(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ecsClient := mock_ecs.NewMockECSClient(ctrl) + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) @@ -250,13 +250,13 @@ func TestSessionReconnectsOnConnectErrors(t *testing.T) { }).Return(time.NewTimer(wsclient.DisconnectTimeout), nil).MinTimes(1), ) acsSession := session{ - containerInstanceARN: testconst.ContainerInstanceARN, - discoverEndpointClient: discoverEndpointClient, - clientFactory: mockClientFactory, - heartbeatTimeout: 20 * time.Millisecond, - heartbeatJitter: 10 * time.Millisecond, - disconnectTimeout: 30 * time.Millisecond, - disconnectJitter: 10 * time.Millisecond, + containerInstanceARN: testconst.ContainerInstanceARN, + ecsClient: ecsClient, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + disconnectTimeout: 30 * time.Millisecond, + disconnectJitter: 10 * time.Millisecond, backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), } @@ -345,8 +345,8 @@ func TestSessionReconnectsWithoutBackoffOnEOFError(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ecsClient := mock_ecs.NewMockECSClient(ctrl) + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) @@ -373,7 +373,7 @@ func TestSessionReconnectsWithoutBackoffOnEOFError(t *testing.T) { ) acsSession := session{ containerInstanceARN: testconst.ContainerInstanceARN, - discoverEndpointClient: discoverEndpointClient, + ecsClient: ecsClient, inactiveInstanceCB: noopFunc, backoff: mockBackoff, clientFactory: mockClientFactory, @@ -394,8 +394,8 @@ func TestSessionReconnectsWithBackoffOnNonEOFError(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ecsClient := mock_ecs.NewMockECSClient(ctrl) + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) @@ -422,15 +422,15 @@ func TestSessionReconnectsWithBackoffOnNonEOFError(t *testing.T) { mockBackoff.EXPECT().Reset().AnyTimes(), ) acsSession := session{ - containerInstanceARN: testconst.ContainerInstanceARN, - discoverEndpointClient: discoverEndpointClient, - inactiveInstanceCB: noopFunc, - backoff: mockBackoff, - clientFactory: mockClientFactory, - heartbeatTimeout: 20 * time.Millisecond, - heartbeatJitter: 10 * time.Millisecond, - disconnectTimeout: 30 * time.Millisecond, - disconnectJitter: 10 * time.Millisecond, + containerInstanceARN: testconst.ContainerInstanceARN, + ecsClient: ecsClient, + inactiveInstanceCB: noopFunc, + backoff: mockBackoff, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + disconnectTimeout: 30 * time.Millisecond, + disconnectJitter: 10 * time.Millisecond, } err := acsSession.Start(ctx) @@ -444,8 +444,8 @@ func TestSessionCallsInactiveInstanceCB(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ecsClient := mock_ecs.NewMockECSClient(ctrl) + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) @@ -477,7 +477,7 @@ func TestSessionCallsInactiveInstanceCB(t *testing.T) { inactiveInstanceReconnectDelay := 200 * time.Millisecond acsSession := session{ containerInstanceARN: testconst.ContainerInstanceARN, - discoverEndpointClient: discoverEndpointClient, + ecsClient: ecsClient, inactiveInstanceCB: inactiveInstanceCB, clientFactory: mockClientFactory, heartbeatTimeout: 20 * time.Millisecond, @@ -499,8 +499,8 @@ func TestSessionReconnectDelayForInactiveInstanceError(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ecsClient := mock_ecs.NewMockECSClient(ctrl) + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) @@ -537,7 +537,7 @@ func TestSessionReconnectDelayForInactiveInstanceError(t *testing.T) { ) acsSession := session{ containerInstanceARN: testconst.ContainerInstanceARN, - discoverEndpointClient: discoverEndpointClient, + ecsClient: ecsClient, inactiveInstanceCB: noopFunc, clientFactory: mockClientFactory, heartbeatTimeout: 20 * time.Millisecond, @@ -559,8 +559,8 @@ func TestSessionReconnectsOnServeErrors(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ecsClient := mock_ecs.NewMockECSClient(ctrl) + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) @@ -586,14 +586,14 @@ func TestSessionReconnectsOnServeErrors(t *testing.T) { ) acsSession := session{ - containerInstanceARN: testconst.ContainerInstanceARN, - discoverEndpointClient: discoverEndpointClient, - inactiveInstanceCB: noopFunc, - clientFactory: mockClientFactory, - heartbeatTimeout: 20 * time.Millisecond, - heartbeatJitter: 10 * time.Millisecond, - disconnectTimeout: 30 * time.Millisecond, - disconnectJitter: 10 * time.Millisecond, + containerInstanceARN: testconst.ContainerInstanceARN, + ecsClient: ecsClient, + inactiveInstanceCB: noopFunc, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + disconnectTimeout: 30 * time.Millisecond, + disconnectJitter: 10 * time.Millisecond, backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), } @@ -608,8 +608,8 @@ func TestSessionStopsWhenContextIsCanceled(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ecsClient := mock_ecs.NewMockECSClient(ctrl) + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) @@ -631,14 +631,14 @@ func TestSessionStopsWhenContextIsCanceled(t *testing.T) { }).Return(inactiveInstanceError), ) acsSession := session{ - containerInstanceARN: testconst.ContainerInstanceARN, - discoverEndpointClient: discoverEndpointClient, - inactiveInstanceCB: noopFunc, - clientFactory: mockClientFactory, - heartbeatTimeout: 20 * time.Millisecond, - heartbeatJitter: 10 * time.Millisecond, - disconnectTimeout: 30 * time.Millisecond, - disconnectJitter: 10 * time.Millisecond, + containerInstanceARN: testconst.ContainerInstanceARN, + ecsClient: ecsClient, + inactiveInstanceCB: noopFunc, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + disconnectTimeout: 30 * time.Millisecond, + disconnectJitter: 10 * time.Millisecond, backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), } @@ -653,8 +653,8 @@ func TestSessionStopsWhenContextIsErrorDueToTimeout(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ecsClient := mock_ecs.NewMockECSClient(ctrl) + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() ctx, cancel := context.WithTimeout(context.Background(), 4*time.Millisecond) defer cancel() @@ -674,7 +674,7 @@ func TestSessionStopsWhenContextIsErrorDueToTimeout(t *testing.T) { acsSession := session{ containerInstanceARN: testconst.ContainerInstanceARN, - discoverEndpointClient: discoverEndpointClient, + ecsClient: ecsClient, inactiveInstanceCB: noopFunc, clientFactory: mockClientFactory, heartbeatTimeout: 20 * time.Millisecond, @@ -694,7 +694,7 @@ func TestSessionReconnectsOnDiscoverPollEndpointError(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) ctx, cancel := context.WithCancel(context.Background()) mockWsClient := mock_wsclient.NewMockClientServer(ctrl) @@ -716,19 +716,19 @@ func TestSessionReconnectsOnDiscoverPollEndpointError(t *testing.T) { gomock.InOrder( // DiscoverPollEndpoint returns an error on its first invocation. - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return("", fmt.Errorf("oops")), + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return("", fmt.Errorf("oops")), // Second invocation returns a success. - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil), + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil), ) acsSession := session{ - containerInstanceARN: testconst.ContainerInstanceARN, - discoverEndpointClient: discoverEndpointClient, - inactiveInstanceCB: noopFunc, - clientFactory: mockClientFactory, - heartbeatTimeout: 20 * time.Millisecond, - heartbeatJitter: 10 * time.Millisecond, - disconnectTimeout: 30 * time.Millisecond, - disconnectJitter: 10 * time.Millisecond, + containerInstanceARN: testconst.ContainerInstanceARN, + ecsClient: ecsClient, + inactiveInstanceCB: noopFunc, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + disconnectTimeout: 30 * time.Millisecond, + disconnectJitter: 10 * time.Millisecond, backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), } @@ -756,8 +756,8 @@ func TestConnectionIsClosedOnIdle(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ecsClient := mock_ecs.NewMockECSClient(ctrl) + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -779,14 +779,14 @@ func TestConnectionIsClosedOnIdle(t *testing.T) { mockWsClient.EXPECT().WriteCloseMessage().Return(nil).AnyTimes() mockWsClient.EXPECT().Close().Return(nil).MinTimes(1) acsSession := session{ - containerInstanceARN: testconst.ContainerInstanceARN, - discoverEndpointClient: discoverEndpointClient, - inactiveInstanceCB: noopFunc, - clientFactory: mockClientFactory, - heartbeatTimeout: 20 * time.Millisecond, - heartbeatJitter: 10 * time.Millisecond, - disconnectTimeout: 30 * time.Millisecond, - disconnectJitter: 10 * time.Millisecond, + containerInstanceARN: testconst.ContainerInstanceARN, + ecsClient: ecsClient, + inactiveInstanceCB: noopFunc, + clientFactory: mockClientFactory, + heartbeatTimeout: 20 * time.Millisecond, + heartbeatJitter: 10 * time.Millisecond, + disconnectTimeout: 30 * time.Millisecond, + disconnectJitter: 10 * time.Millisecond, backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), } @@ -815,7 +815,7 @@ func TestSessionDoesntLeakGoroutines(t *testing.T) { defer ctrl.Finish() payloadMessageHandler := mock_session.NewMockPayloadMessageHandler(ctrl) - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) ctx, cancel := context.WithCancel(context.Background()) closeWS := make(chan bool) @@ -835,7 +835,7 @@ func TestSessionDoesntLeakGoroutines(t *testing.T) { }() timesConnected := 0 - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(testconst.ContainerInstanceARN).Return(fakeServer.URL, nil). + ecsClient.EXPECT().DiscoverPollEndpoint(testconst.ContainerInstanceARN).Return(fakeServer.URL, nil). AnyTimes().Do(func(_ interface{}) { timesConnected++ }) @@ -851,17 +851,17 @@ func TestSessionDoesntLeakGoroutines(t *testing.T) { ended := make(chan bool, 1) go func() { acsSession := session{ - containerInstanceARN: testconst.ContainerInstanceARN, - credentialsProvider: testCreds, - dockerVersion: dockerVersion, - minAgentConfig: testMinAgentConfig, - discoverEndpointClient: discoverEndpointClient, - inactiveInstanceCB: noopFunc, - clientFactory: acsclient.NewACSClientFactory(), - metricsFactory: metricsfactory.NewNopEntryFactory(), - payloadMessageHandler: payloadMessageHandler, - heartbeatTimeout: 1 * time.Second, - doctor: emptyDoctor, + containerInstanceARN: testconst.ContainerInstanceARN, + credentialsProvider: testCreds, + dockerVersion: dockerVersion, + minAgentConfig: testMinAgentConfig, + ecsClient: ecsClient, + inactiveInstanceCB: noopFunc, + clientFactory: acsclient.NewACSClientFactory(), + metricsFactory: metricsfactory.NewNopEntryFactory(), + payloadMessageHandler: payloadMessageHandler, + heartbeatTimeout: 1 * time.Second, + doctor: emptyDoctor, backoff: retry.NewExponentialBackoff(connectionBackoffMin, connectionBackoffMax, connectionBackoffJitter, connectionBackoffMultiplier), } @@ -902,7 +902,7 @@ func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) { defer ctrl.Finish() credentialsMetadataSetter := mock_session.NewMockCredentialsMetadataSetter(ctrl) - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) ctx, cancel := context.WithCancel(context.Background()) closeWS := make(chan bool) fakeServer, serverIn, requestsChan, errChan, err := startFakeACSServer(closeWS) @@ -922,7 +922,7 @@ func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) { }() // DiscoverPollEndpoint returns the URL for the server that we started. - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(testconst.ContainerInstanceARN).Return(fakeServer.URL, nil) + ecsClient.EXPECT().DiscoverPollEndpoint(testconst.ContainerInstanceARN).Return(fakeServer.URL, nil) credentialsManager := mock_credentials.NewMockManager(ctrl) @@ -930,7 +930,7 @@ func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) { go func() { acsSession := NewSession(testconst.ContainerInstanceARN, testconst.ClusterARN, - discoverEndpointClient, + ecsClient, testCreds, noopFunc, acsclient.NewACSClientFactory(), @@ -1004,8 +1004,8 @@ func TestSessionCorrectlySetsSendCredentials(t *testing.T) { defer ctrl.Finish() const numInvocations = 10 - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ecsClient := mock_ecs.NewMockECSClient(ctrl) + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) mockWsClient := mock_wsclient.NewMockClientServer(ctrl) @@ -1021,7 +1021,7 @@ func TestSessionCorrectlySetsSendCredentials(t *testing.T) { acsSession := NewSession(testconst.ContainerInstanceARN, testconst.ClusterARN, - discoverEndpointClient, + ecsClient, nil, noopFunc, mockClientFactory, @@ -1077,8 +1077,8 @@ func TestSessionCorrectlySetsSendCredentials(t *testing.T) { func TestSessionReconnectCorrectlySetsAcsUrl(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ecsClient := mock_ecs.NewMockECSClient(ctrl) + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) mockBackoff := mock_retry.NewMockBackoff(ctrl) @@ -1121,7 +1121,7 @@ func TestSessionReconnectCorrectlySetsAcsUrl(t *testing.T) { ) acsSession := NewSession(testconst.ContainerInstanceARN, testconst.ClusterARN, - discoverEndpointClient, + ecsClient, nil, noopFunc, mockClientFactory, @@ -1159,7 +1159,7 @@ func TestStartSessionHandlesAttachResourceMessages(t *testing.T) { defer ctrl.Finish() resourceHandler := mock_session.NewMockResourceHandler(ctrl) - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) + ecsClient := mock_ecs.NewMockECSClient(ctrl) ctx, cancel := context.WithCancel(context.Background()) closeWS := make(chan bool) fakeServer, serverIn, requestsChan, errChan, err := startFakeACSServer(closeWS) @@ -1179,13 +1179,13 @@ func TestStartSessionHandlesAttachResourceMessages(t *testing.T) { }() // DiscoverPollEndpoint returns the URL for the server that we started. - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(testconst.ContainerInstanceARN).Return(fakeServer.URL, nil) + ecsClient.EXPECT().DiscoverPollEndpoint(testconst.ContainerInstanceARN).Return(fakeServer.URL, nil) ended := make(chan bool, 1) go func() { acsSession := NewSession(testconst.ContainerInstanceARN, testconst.ClusterARN, - discoverEndpointClient, + ecsClient, testCreds, noopFunc, acsclient.NewACSClientFactory(), @@ -1247,8 +1247,8 @@ func TestSessionCallsAddUpdateRequestHandlers(t *testing.T) { addUpdateRequestHandlersCalled = true } - discoverEndpointClient := mock_api.NewMockECSDiscoverEndpointSDK(ctrl) - discoverEndpointClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() + ecsClient := mock_ecs.NewMockECSClient(ctrl) + ecsClient.EXPECT().DiscoverPollEndpoint(gomock.Any()).Return(acsURL, nil).AnyTimes() ctx, cancel := context.WithCancel(context.Background()) mockWsClient := mock_wsclient.NewMockClientServer(ctrl) @@ -1270,7 +1270,7 @@ func TestSessionCallsAddUpdateRequestHandlers(t *testing.T) { acsSession := NewSession(testconst.ContainerInstanceARN, testconst.ClusterARN, - discoverEndpointClient, + ecsClient, nil, noopFunc, mockClientFactory, diff --git a/ecs-agent/api/generate_mocks.go b/ecs-agent/api/generate_mocks.go deleted file mode 100644 index 57e12a724fc..00000000000 --- a/ecs-agent/api/generate_mocks.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -//go:generate mockgen -destination=mocks/api_mocks.go -copyright_file=../../scripts/copyright_file github.com/aws/amazon-ecs-agent/ecs-agent/api ECSDiscoverEndpointSDK - -package api diff --git a/ecs-agent/api/interface.go b/ecs-agent/api/interface.go deleted file mode 100644 index 22324f0a367..00000000000 --- a/ecs-agent/api/interface.go +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. -package api - -// ECSDiscoverEndpointSDK is an interface with customized ecs client that -// implements the DiscoverPollEndpoint, DiscoverTelemetryEndpoint, and DiscoverServiceConnectEndpoint -type ECSDiscoverEndpointSDK interface { - DiscoverPollEndpoint(containerInstanceArn string) (string, error) - DiscoverTelemetryEndpoint(containerInstanceArn string) (string, error) - DiscoverServiceConnectEndpoint(containerInstanceArn string) (string, error) -} diff --git a/ecs-agent/api/mocks/api_mocks.go b/ecs-agent/api/mocks/api_mocks.go deleted file mode 100644 index f2a57c4d1c5..00000000000 --- a/ecs-agent/api/mocks/api_mocks.go +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. -// - -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/aws/amazon-ecs-agent/ecs-agent/api (interfaces: ECSDiscoverEndpointSDK) - -// Package mock_api is a generated GoMock package. -package mock_api - -import ( - reflect "reflect" - - gomock "github.com/golang/mock/gomock" -) - -// MockECSDiscoverEndpointSDK is a mock of ECSDiscoverEndpointSDK interface. -type MockECSDiscoverEndpointSDK struct { - ctrl *gomock.Controller - recorder *MockECSDiscoverEndpointSDKMockRecorder -} - -// MockECSDiscoverEndpointSDKMockRecorder is the mock recorder for MockECSDiscoverEndpointSDK. -type MockECSDiscoverEndpointSDKMockRecorder struct { - mock *MockECSDiscoverEndpointSDK -} - -// NewMockECSDiscoverEndpointSDK creates a new mock instance. -func NewMockECSDiscoverEndpointSDK(ctrl *gomock.Controller) *MockECSDiscoverEndpointSDK { - mock := &MockECSDiscoverEndpointSDK{ctrl: ctrl} - mock.recorder = &MockECSDiscoverEndpointSDKMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockECSDiscoverEndpointSDK) EXPECT() *MockECSDiscoverEndpointSDKMockRecorder { - return m.recorder -} - -// DiscoverPollEndpoint mocks base method. -func (m *MockECSDiscoverEndpointSDK) DiscoverPollEndpoint(arg0 string) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DiscoverPollEndpoint", arg0) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DiscoverPollEndpoint indicates an expected call of DiscoverPollEndpoint. -func (mr *MockECSDiscoverEndpointSDKMockRecorder) DiscoverPollEndpoint(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscoverPollEndpoint", reflect.TypeOf((*MockECSDiscoverEndpointSDK)(nil).DiscoverPollEndpoint), arg0) -} - -// DiscoverServiceConnectEndpoint mocks base method. -func (m *MockECSDiscoverEndpointSDK) DiscoverServiceConnectEndpoint(arg0 string) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DiscoverServiceConnectEndpoint", arg0) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DiscoverServiceConnectEndpoint indicates an expected call of DiscoverServiceConnectEndpoint. -func (mr *MockECSDiscoverEndpointSDKMockRecorder) DiscoverServiceConnectEndpoint(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscoverServiceConnectEndpoint", reflect.TypeOf((*MockECSDiscoverEndpointSDK)(nil).DiscoverServiceConnectEndpoint), arg0) -} - -// DiscoverTelemetryEndpoint mocks base method. -func (m *MockECSDiscoverEndpointSDK) DiscoverTelemetryEndpoint(arg0 string) (string, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DiscoverTelemetryEndpoint", arg0) - ret0, _ := ret[0].(string) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// DiscoverTelemetryEndpoint indicates an expected call of DiscoverTelemetryEndpoint. -func (mr *MockECSDiscoverEndpointSDKMockRecorder) DiscoverTelemetryEndpoint(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DiscoverTelemetryEndpoint", reflect.TypeOf((*MockECSDiscoverEndpointSDK)(nil).DiscoverTelemetryEndpoint), arg0) -} diff --git a/ecs-agent/utils/utils.go b/ecs-agent/utils/utils.go index fa4581a5d30..7b08a499af2 100644 --- a/ecs-agent/utils/utils.go +++ b/ecs-agent/utils/utils.go @@ -15,6 +15,8 @@ package utils import ( "reflect" "strconv" + + "github.com/aws/aws-sdk-go/aws" ) func ZeroOrNil(obj interface{}) bool { @@ -49,3 +51,11 @@ func Uint16SliceToStringSlice(slice []uint16) []*string { } return stringSlice } + +// Int64PtrToIntPtr converts a *int64 to *int. +func Int64PtrToIntPtr(int64ptr *int64) *int { + if int64ptr == nil { + return nil + } + return aws.Int(int(aws.Int64Value(int64ptr))) +} diff --git a/ecs-agent/utils/utils_test.go b/ecs-agent/utils/utils_test.go index bcbd390b210..245441091bf 100644 --- a/ecs-agent/utils/utils_test.go +++ b/ecs-agent/utils/utils_test.go @@ -16,6 +16,7 @@ import ( "strconv" "testing" + "github.com/aws/aws-sdk-go/aws" "github.com/stretchr/testify/assert" ) @@ -89,3 +90,20 @@ func TestUint16SliceToStringSlice(t *testing.T) { }) } } + +func TestInt64PtrToIntPtr(t *testing.T) { + testCases := []struct { + input *int64 + expectedOutput *int + name string + }{ + {nil, nil, "nil"}, + {aws.Int64(2147483647), aws.Int(2147483647), "smallest max value type int can hold"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expectedOutput, Int64PtrToIntPtr(tc.input)) + }) + } +}