From 0bab10ff80998e500aa9869f55c52742a5ba832c Mon Sep 17 00:00:00 2001 From: Dane H Lim Date: Tue, 21 Nov 2023 16:53:50 -0800 Subject: [PATCH] Trim managed agent reason + add retries for getting instance identity signature --- .../ecs-agent/api/ecs/client/ecs_client.go | 44 ++++++++++++++++--- ecs-agent/api/ecs/client/ecs_client.go | 44 ++++++++++++++++--- ecs-agent/api/ecs/client/ecs_client_test.go | 25 +++++++++++ 3 files changed, 101 insertions(+), 12 deletions(-) diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client.go index 978304d57cf..8b5ed661e18 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/client/ecs_client.go @@ -294,8 +294,22 @@ func (client *ecsClient) setInstanceIdentity( registerRequest.InstanceIdentityDocument = &instanceIdentityDoc if iidRetrieved { - instanceIdentitySignature, err = client.ec2metadata. - GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource) + ctx, cancel = context.WithTimeout(context.Background(), setInstanceIdRetryTimeOut) + defer cancel() + err = retry.RetryWithBackoffCtx(ctx, backoff, func() error { + var attemptErr error + logger.Debug("Attempting to get Instance Identity Signature") + instanceIdentitySignature, attemptErr = client.ec2metadata. + GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource) + if attemptErr != nil { + logger.Debug("Unable to get instance identity signature, retrying", logger.Fields{ + field.Error: attemptErr, + }) + return apierrors.NewRetriableError(apierrors.NewRetriable(true), attemptErr) + } + logger.Debug("Successfully retrieved Instance Identity Signature") + return nil + }) if err != nil { logger.Error("Unable to get instance identity signature", logger.Fields{ field.Error: err, @@ -521,7 +535,7 @@ func (client *ecsClient) submitTaskStateChange(change ecs.TaskStateChange) error PullStartedAt: change.PullStartedAt, PullStoppedAt: change.PullStoppedAt, ExecutionStoppedAt: change.ExecutionStoppedAt, - ManagedAgents: change.ManagedAgents, + ManagedAgents: formatManagedAgents(change.ManagedAgents), Containers: formatContainers(change.Containers, client.shouldExcludeIPv6PortBinding, change.TaskARN), } @@ -752,18 +766,29 @@ func (client *ecsClient) UpdateContainerInstancesState(instanceARN string, statu return err } +func formatManagedAgents(managedAgents []*ecsmodel.ManagedAgentStateChange) []*ecsmodel.ManagedAgentStateChange { + var result []*ecsmodel.ManagedAgentStateChange + for _, m := range managedAgents { + if m.Reason != nil { + m.Reason = trimStringPtr(m.Reason, ecsMaxContainerReasonLength) + } + result = append(result, m) + } + return result +} + func formatContainers(containers []*ecsmodel.ContainerStateChange, shouldExcludeIPv6PortBinding bool, taskARN string) []*ecsmodel.ContainerStateChange { var result []*ecsmodel.ContainerStateChange for _, c := range containers { if c.RuntimeId != nil { - c.RuntimeId = aws.String(trimString(aws.StringValue(c.RuntimeId), ecsMaxRuntimeIDLength)) + c.RuntimeId = trimStringPtr(c.RuntimeId, ecsMaxRuntimeIDLength) } if c.Reason != nil { - c.Reason = aws.String(trimString(aws.StringValue(c.Reason), ecsMaxContainerReasonLength)) + c.Reason = trimStringPtr(c.Reason, ecsMaxContainerReasonLength) } if c.ImageDigest != nil { - c.ImageDigest = aws.String(trimString(aws.StringValue(c.ImageDigest), ecsMaxImageDigestLength)) + c.ImageDigest = trimStringPtr(c.ImageDigest, ecsMaxImageDigestLength) } if shouldExcludeIPv6PortBinding { c.NetworkBindings = excludeIPv6PortBindingFromNetworkBindings(c.NetworkBindings, @@ -791,6 +816,13 @@ func excludeIPv6PortBindingFromNetworkBindings(networkBindings []*ecsmodel.Netwo return result } +func trimStringPtr(inputStringPtr *string, maxLen int) *string { + if inputStringPtr == nil { + return nil + } + return aws.String(trimString(aws.StringValue(inputStringPtr), maxLen)) +} + func trimString(inputString string, maxLen int) string { if len(inputString) > maxLen { trimmed := inputString[0:maxLen] diff --git a/ecs-agent/api/ecs/client/ecs_client.go b/ecs-agent/api/ecs/client/ecs_client.go index 978304d57cf..8b5ed661e18 100644 --- a/ecs-agent/api/ecs/client/ecs_client.go +++ b/ecs-agent/api/ecs/client/ecs_client.go @@ -294,8 +294,22 @@ func (client *ecsClient) setInstanceIdentity( registerRequest.InstanceIdentityDocument = &instanceIdentityDoc if iidRetrieved { - instanceIdentitySignature, err = client.ec2metadata. - GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource) + ctx, cancel = context.WithTimeout(context.Background(), setInstanceIdRetryTimeOut) + defer cancel() + err = retry.RetryWithBackoffCtx(ctx, backoff, func() error { + var attemptErr error + logger.Debug("Attempting to get Instance Identity Signature") + instanceIdentitySignature, attemptErr = client.ec2metadata. + GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource) + if attemptErr != nil { + logger.Debug("Unable to get instance identity signature, retrying", logger.Fields{ + field.Error: attemptErr, + }) + return apierrors.NewRetriableError(apierrors.NewRetriable(true), attemptErr) + } + logger.Debug("Successfully retrieved Instance Identity Signature") + return nil + }) if err != nil { logger.Error("Unable to get instance identity signature", logger.Fields{ field.Error: err, @@ -521,7 +535,7 @@ func (client *ecsClient) submitTaskStateChange(change ecs.TaskStateChange) error PullStartedAt: change.PullStartedAt, PullStoppedAt: change.PullStoppedAt, ExecutionStoppedAt: change.ExecutionStoppedAt, - ManagedAgents: change.ManagedAgents, + ManagedAgents: formatManagedAgents(change.ManagedAgents), Containers: formatContainers(change.Containers, client.shouldExcludeIPv6PortBinding, change.TaskARN), } @@ -752,18 +766,29 @@ func (client *ecsClient) UpdateContainerInstancesState(instanceARN string, statu return err } +func formatManagedAgents(managedAgents []*ecsmodel.ManagedAgentStateChange) []*ecsmodel.ManagedAgentStateChange { + var result []*ecsmodel.ManagedAgentStateChange + for _, m := range managedAgents { + if m.Reason != nil { + m.Reason = trimStringPtr(m.Reason, ecsMaxContainerReasonLength) + } + result = append(result, m) + } + return result +} + func formatContainers(containers []*ecsmodel.ContainerStateChange, shouldExcludeIPv6PortBinding bool, taskARN string) []*ecsmodel.ContainerStateChange { var result []*ecsmodel.ContainerStateChange for _, c := range containers { if c.RuntimeId != nil { - c.RuntimeId = aws.String(trimString(aws.StringValue(c.RuntimeId), ecsMaxRuntimeIDLength)) + c.RuntimeId = trimStringPtr(c.RuntimeId, ecsMaxRuntimeIDLength) } if c.Reason != nil { - c.Reason = aws.String(trimString(aws.StringValue(c.Reason), ecsMaxContainerReasonLength)) + c.Reason = trimStringPtr(c.Reason, ecsMaxContainerReasonLength) } if c.ImageDigest != nil { - c.ImageDigest = aws.String(trimString(aws.StringValue(c.ImageDigest), ecsMaxImageDigestLength)) + c.ImageDigest = trimStringPtr(c.ImageDigest, ecsMaxImageDigestLength) } if shouldExcludeIPv6PortBinding { c.NetworkBindings = excludeIPv6PortBindingFromNetworkBindings(c.NetworkBindings, @@ -791,6 +816,13 @@ func excludeIPv6PortBindingFromNetworkBindings(networkBindings []*ecsmodel.Netwo return result } +func trimStringPtr(inputStringPtr *string, maxLen int) *string { + if inputStringPtr == nil { + return nil + } + return aws.String(trimString(aws.StringValue(inputStringPtr), maxLen)) +} + func trimString(inputString string, maxLen int) string { if len(inputString) > maxLen { trimmed := inputString[0:maxLen] diff --git a/ecs-agent/api/ecs/client/ecs_client_test.go b/ecs-agent/api/ecs/client/ecs_client_test.go index 095be1f9841..88601c3fb32 100644 --- a/ecs-agent/api/ecs/client/ecs_client_test.go +++ b/ecs-agent/api/ecs/client/ecs_client_test.go @@ -318,6 +318,10 @@ func TestRegisterContainerInstance(t *testing.T) { name: "basic case", mockCfgAccessorOverride: nil, }, + { + name: "retry GetDynamicData", + mockCfgAccessorOverride: nil, + }, { name: "no instance identity doc", mockCfgAccessorOverride: func(cfgAccessor *mock_config.MockAgentConfigAccessor) { @@ -386,6 +390,8 @@ func TestRegisterContainerInstance(t *testing.T) { Return("", errors.New("fake unit test error")), mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentResource). Return(expectedIID, nil), + mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource). + Return("", errors.New("fake unit test error")), mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource). Return(expectedIIDSig, nil), ) @@ -1361,6 +1367,25 @@ func TestWithIPv6PortBindingExcludedSetFalse(t *testing.T) { assert.NoError(t, err, "Unable to submit container state change") } +func TestTrimStringPtr(t *testing.T) { + const testMaxLen = 32 + testCases := []struct { + inputStringPtr *string + expectedOutput *string + name string + }{ + {nil, nil, "nil"}, + {aws.String("abc"), aws.String("abc"), "input does not exceed max length"}, + {aws.String("abcdefghijklmnopqrstuvwxyz1234567890"), + aws.String("abcdefghijklmnopqrstuvwxyz123456"), "input exceeds max length"}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.expectedOutput, trimStringPtr(tc.inputStringPtr, testMaxLen)) + }) + } +} + func extractTagsMapFromRegisterContainerInstanceInput(req *ecsmodel.RegisterContainerInstanceInput) map[string]string { tagsMap := make(map[string]string, len(req.Tags)) for i := range req.Tags {