Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trim managed agent reason + add retries for getting instance identity signature #4042

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

44 changes: 38 additions & 6 deletions ecs-agent/api/ecs/client/ecs_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,22 @@ func (client *ecsClient) setInstanceIdentity(
registerRequest.InstanceIdentityDocument = &instanceIdentityDoc

if iidRetrieved {
instanceIdentitySignature, err = client.ec2metadata.
GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource)
ctx, cancel = context.WithTimeout(context.Background(), setInstanceIdRetryTimeOut)
defer cancel()
err = retry.RetryWithBackoffCtx(ctx, backoff, func() error {
var attemptErr error
logger.Debug("Attempting to get Instance Identity Signature")
instanceIdentitySignature, attemptErr = client.ec2metadata.
GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource)
if attemptErr != nil {
logger.Debug("Unable to get instance identity signature, retrying", logger.Fields{
field.Error: attemptErr,
})
return apierrors.NewRetriableError(apierrors.NewRetriable(true), attemptErr)
}
logger.Debug("Successfully retrieved Instance Identity Signature")
return nil
})
if err != nil {
logger.Error("Unable to get instance identity signature", logger.Fields{
field.Error: err,
Expand Down Expand Up @@ -521,7 +535,7 @@ func (client *ecsClient) submitTaskStateChange(change ecs.TaskStateChange) error
PullStartedAt: change.PullStartedAt,
PullStoppedAt: change.PullStoppedAt,
ExecutionStoppedAt: change.ExecutionStoppedAt,
ManagedAgents: change.ManagedAgents,
ManagedAgents: formatManagedAgents(change.ManagedAgents),
Containers: formatContainers(change.Containers, client.shouldExcludeIPv6PortBinding, change.TaskARN),
}

Expand Down Expand Up @@ -752,18 +766,29 @@ func (client *ecsClient) UpdateContainerInstancesState(instanceARN string, statu
return err
}

func formatManagedAgents(managedAgents []*ecsmodel.ManagedAgentStateChange) []*ecsmodel.ManagedAgentStateChange {
var result []*ecsmodel.ManagedAgentStateChange
for _, m := range managedAgents {
if m.Reason != nil {
m.Reason = trimStringPtr(m.Reason, ecsMaxContainerReasonLength)
}
result = append(result, m)
}
return result
}

func formatContainers(containers []*ecsmodel.ContainerStateChange, shouldExcludeIPv6PortBinding bool,
taskARN string) []*ecsmodel.ContainerStateChange {
var result []*ecsmodel.ContainerStateChange
for _, c := range containers {
if c.RuntimeId != nil {
c.RuntimeId = aws.String(trimString(aws.StringValue(c.RuntimeId), ecsMaxRuntimeIDLength))
c.RuntimeId = trimStringPtr(c.RuntimeId, ecsMaxRuntimeIDLength)
}
if c.Reason != nil {
c.Reason = aws.String(trimString(aws.StringValue(c.Reason), ecsMaxContainerReasonLength))
c.Reason = trimStringPtr(c.Reason, ecsMaxContainerReasonLength)
}
if c.ImageDigest != nil {
c.ImageDigest = aws.String(trimString(aws.StringValue(c.ImageDigest), ecsMaxImageDigestLength))
c.ImageDigest = trimStringPtr(c.ImageDigest, ecsMaxImageDigestLength)
}
if shouldExcludeIPv6PortBinding {
c.NetworkBindings = excludeIPv6PortBindingFromNetworkBindings(c.NetworkBindings,
Expand Down Expand Up @@ -791,6 +816,13 @@ func excludeIPv6PortBindingFromNetworkBindings(networkBindings []*ecsmodel.Netwo
return result
}

func trimStringPtr(inputStringPtr *string, maxLen int) *string {
if inputStringPtr == nil {
return nil
}
return aws.String(trimString(aws.StringValue(inputStringPtr), maxLen))
}

func trimString(inputString string, maxLen int) string {
if len(inputString) > maxLen {
trimmed := inputString[0:maxLen]
Expand Down
25 changes: 25 additions & 0 deletions ecs-agent/api/ecs/client/ecs_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,10 @@ func TestRegisterContainerInstance(t *testing.T) {
name: "basic case",
mockCfgAccessorOverride: nil,
},
{
name: "retry GetDynamicData",
mockCfgAccessorOverride: nil,
},
{
name: "no instance identity doc",
mockCfgAccessorOverride: func(cfgAccessor *mock_config.MockAgentConfigAccessor) {
Expand Down Expand Up @@ -386,6 +390,8 @@ func TestRegisterContainerInstance(t *testing.T) {
Return("", errors.New("fake unit test error")),
mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentResource).
Return(expectedIID, nil),
mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource).
Return("", errors.New("fake unit test error")),
amogh09 marked this conversation as resolved.
Show resolved Hide resolved
mockEC2Metadata.EXPECT().GetDynamicData(ec2.InstanceIdentityDocumentSignatureResource).
Return(expectedIIDSig, nil),
)
Expand Down Expand Up @@ -1361,6 +1367,25 @@ func TestWithIPv6PortBindingExcludedSetFalse(t *testing.T) {
assert.NoError(t, err, "Unable to submit container state change")
}

func TestTrimStringPtr(t *testing.T) {
const testMaxLen = 32
testCases := []struct {
inputStringPtr *string
expectedOutput *string
name string
}{
{nil, nil, "nil"},
{aws.String("abc"), aws.String("abc"), "input does not exceed max length"},
{aws.String("abcdefghijklmnopqrstuvwxyz1234567890"),
aws.String("abcdefghijklmnopqrstuvwxyz123456"), "input exceeds max length"},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expectedOutput, trimStringPtr(tc.inputStringPtr, testMaxLen))
})
}
}

func extractTagsMapFromRegisterContainerInstanceInput(req *ecsmodel.RegisterContainerInstanceInput) map[string]string {
tagsMap := make(map[string]string, len(req.Tags))
for i := range req.Tags {
Expand Down