diff --git a/br/pkg/aws/BUILD.bazel b/br/pkg/aws/BUILD.bazel index d250a7e757de6..7f325f556a019 100644 --- a/br/pkg/aws/BUILD.bazel +++ b/br/pkg/aws/BUILD.bazel @@ -27,9 +27,11 @@ go_test( name = "aws_test", srcs = ["ebs_test.go"], embed = [":aws"], + shard_count = 3, deps = [ "@com_github_aws_aws_sdk_go//aws", "@com_github_aws_aws_sdk_go//service/ec2", + "@com_github_aws_aws_sdk_go//service/ec2/ec2iface", "@com_github_stretchr_testify//require", ], ) diff --git a/br/pkg/aws/ebs.go b/br/pkg/aws/ebs.go index e457172415b32..f3252cbf46fa5 100644 --- a/br/pkg/aws/ebs.go +++ b/br/pkg/aws/ebs.go @@ -251,6 +251,9 @@ func (e *EC2Session) WaitSnapshotsCreated(snapIDMap map[string]string, progress if *s.State == ec2.SnapshotStateCompleted { log.Info("snapshot completed", zap.String("id", *s.SnapshotId)) totalVolumeSize += *s.VolumeSize + } else if *s.State == ec2.SnapshotStateError { + log.Error("snapshot failed", zap.String("id", *s.SnapshotId), zap.String("error", (*s.StateMessage))) + return 0, errors.Errorf("snapshot %s failed", *s.SnapshotId) } else { log.Debug("snapshot creating...", zap.Stringer("snap", s)) uncompletedSnapshots = append(uncompletedSnapshots, s.SnapshotId) diff --git a/br/pkg/aws/ebs_test.go b/br/pkg/aws/ebs_test.go index e55ea68c86e04..96fcf358ff952 100644 --- a/br/pkg/aws/ebs_test.go +++ b/br/pkg/aws/ebs_test.go @@ -14,10 +14,12 @@ package aws import ( + "context" "testing" awsapi "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go/service/ec2/ec2iface" "github.com/stretchr/testify/require" ) @@ -76,3 +78,127 @@ func TestHandleDescribeVolumesResponse(t *testing.T) { require.Equal(t, int64(4), createdVolumeSize) require.Equal(t, 1, len(unfinishedVolumes)) } + +type mockEC2 struct { + ec2iface.EC2API + output ec2.DescribeSnapshotsOutput +} + +func (m mockEC2) DescribeSnapshots(*ec2.DescribeSnapshotsInput) (*ec2.DescribeSnapshotsOutput, error) { + return &m.output, nil +} + +func NewMockEc2Session(mock mockEC2) *EC2Session { + return &EC2Session{ + ec2: mock, + } +} + +func TestWaitSnapshotsCreated(t *testing.T) { + snapIdMap := map[string]string{ + "vol-1": "snap-1", + "vol-2": "snap-2", + } + + cases := []struct { + desc string + snapshotsOutput ec2.DescribeSnapshotsOutput + expectedSize int64 + expectErr bool + expectTimeout bool + }{ + { + desc: "snapshots are all completed", + snapshotsOutput: ec2.DescribeSnapshotsOutput{ + Snapshots: []*ec2.Snapshot{ + { + SnapshotId: awsapi.String("snap-1"), + VolumeSize: awsapi.Int64(1), + State: awsapi.String(ec2.SnapshotStateCompleted), + }, + { + SnapshotId: awsapi.String("snap-2"), + VolumeSize: awsapi.Int64(2), + State: awsapi.String(ec2.SnapshotStateCompleted), + }, + }, + }, + expectedSize: 3, + expectErr: false, + }, + { + desc: "snapshot failed", + snapshotsOutput: ec2.DescribeSnapshotsOutput{ + Snapshots: []*ec2.Snapshot{ + { + SnapshotId: awsapi.String("snap-1"), + VolumeSize: awsapi.Int64(1), + State: awsapi.String(ec2.SnapshotStateCompleted), + }, + { + SnapshotId: awsapi.String("snap-2"), + State: awsapi.String(ec2.SnapshotStateError), + StateMessage: awsapi.String("snapshot failed"), + }, + }, + }, + expectedSize: 0, + expectErr: true, + }, + { + desc: "snapshots pending", + snapshotsOutput: ec2.DescribeSnapshotsOutput{ + Snapshots: []*ec2.Snapshot{ + { + SnapshotId: awsapi.String("snap-1"), + VolumeSize: awsapi.Int64(1), + State: awsapi.String(ec2.SnapshotStateCompleted), + }, + { + SnapshotId: awsapi.String("snap-2"), + State: awsapi.String(ec2.SnapshotStatePending), + }, + }, + }, + expectTimeout: true, + }, + } + + for _, c := range cases { + e := NewMockEc2Session(mockEC2{ + output: c.snapshotsOutput, + }) + + if c.expectTimeout { + func() { + // We wait 5s before checking snapshots + ctx, cancel := context.WithTimeout(context.Background(), 6) + defer cancel() + + done := make(chan struct{}) + go func() { + _, _ = e.WaitSnapshotsCreated(snapIdMap, nil) + done <- struct{}{} + }() + + select { + case <-done: + t.Fatal("WaitSnapshotsCreated should not return before timeout") + case <-ctx.Done(): + require.True(t, true) + } + }() + + continue + } + + size, err := e.WaitSnapshotsCreated(snapIdMap, nil) + if c.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + + require.Equal(t, c.expectedSize, size) + } +}