diff --git a/controllers/awscluster_controller_test.go b/controllers/awscluster_controller_test.go index 828bfae822..5b7ba216ee 100644 --- a/controllers/awscluster_controller_test.go +++ b/controllers/awscluster_controller_test.go @@ -303,6 +303,11 @@ func TestAWSClusterReconcilerIntegrationTests(t *testing.T) { mockedCreateLBV2Calls(t, e) mockedDescribeInstanceCall(m) mockedDescribeAvailabilityZones(m, []string{"us-east-1c", "us-east-1a"}) + mockedDescribeTargetGroupsCall(t, e) + mockedCreateTargetGroupCall(t, e) + mockedModifyTargetGroupAttributes(t, e) + mockedDescribeListenersCall(t, e) + mockedCreateListenerCall(t, e) } expect(ec2Mock.EXPECT(), elbv2Mock.EXPECT()) diff --git a/controllers/helpers_test.go b/controllers/helpers_test.go index f4511e9508..09d4402af8 100644 --- a/controllers/helpers_test.go +++ b/controllers/helpers_test.go @@ -29,6 +29,7 @@ import ( infrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/api/v1beta2" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope" + "sigs.k8s.io/cluster-api-provider-aws/v2/test/helpers" "sigs.k8s.io/cluster-api-provider-aws/v2/test/mocks" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" "sigs.k8s.io/cluster-api/util/conditions" @@ -39,6 +40,7 @@ const DNSName = "www.google.com" var ( lbName = aws.String("test-cluster-apiserver") lbArn = aws.String("loadbalancer::arn") + tgArn = aws.String("arn::target-group") describeLBInput = &elb.DescribeLoadBalancersInput{ LoadBalancerNames: aws.StringSlice([]string{"test-cluster-apiserver"}), } @@ -291,6 +293,157 @@ func mockedCreateLBV2Calls(t *testing.T, m *mocks.MockELBV2APIMockRecorder) { })).MaxTimes(1) } +func mockedDescribeTargetGroupsCall(t *testing.T, m *mocks.MockELBV2APIMockRecorder) { + t.Helper() + m.DescribeTargetGroups(gomock.Eq(&elbv2.DescribeTargetGroupsInput{ + LoadBalancerArn: lbArn, + })). + Return(&elbv2.DescribeTargetGroupsOutput{ + NextMarker: new(string), + TargetGroups: []*elbv2.TargetGroup{ + { + HealthCheckEnabled: aws.Bool(true), + HealthCheckIntervalSeconds: new(int64), + HealthCheckPath: new(string), + HealthCheckPort: new(string), + HealthCheckProtocol: new(string), + HealthCheckTimeoutSeconds: new(int64), + HealthyThresholdCount: new(int64), + IpAddressType: new(string), + LoadBalancerArns: []*string{lbArn}, + Matcher: &elbv2.Matcher{}, + Port: new(int64), + Protocol: new(string), + ProtocolVersion: new(string), + TargetGroupArn: tgArn, + TargetGroupName: new(string), + TargetType: new(string), + UnhealthyThresholdCount: new(int64), + VpcId: new(string), + }}, + }, nil) +} + +func mockedCreateTargetGroupCall(t *testing.T, m *mocks.MockELBV2APIMockRecorder) { + t.Helper() + m.CreateTargetGroup(helpers.PartialMatchCreateTargetGroupInput(t, &elbv2.CreateTargetGroupInput{ + HealthCheckEnabled: aws.Bool(true), + HealthCheckIntervalSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckIntervalSec), + HealthCheckPort: aws.String(infrav1.DefaultAPIServerPortString), + HealthCheckProtocol: aws.String("TCP"), + HealthCheckTimeoutSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckTimeoutSec), + HealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerHealthThresholdCount), + // Note: this is treated as a prefix with the partial matcher. + Name: aws.String("apiserver-target"), + Port: aws.Int64(infrav1.DefaultAPIServerPort), + Protocol: aws.String("TCP"), + Tags: []*elbv2.Tag{ + { + Key: aws.String("Name"), + Value: aws.String("bar-apiserver"), + }, + { + Key: aws.String("sigs.k8s.io/cluster-api-provider-aws/cluster/test-cluster"), + Value: aws.String("owned"), + }, + { + Key: aws.String("sigs.k8s.io/cluster-api-provider-aws/role"), + Value: aws.String("apiserver"), + }, + }, + UnhealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerUnhealthThresholdCount), + VpcId: aws.String("vpc-exists"), + })).Return(&elbv2.CreateTargetGroupOutput{ + TargetGroups: []*elbv2.TargetGroup{{ + HealthCheckEnabled: aws.Bool(true), + HealthCheckIntervalSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckIntervalSec), + HealthCheckPort: aws.String(infrav1.DefaultAPIServerPortString), + HealthCheckProtocol: aws.String("TCP"), + HealthCheckTimeoutSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckTimeoutSec), + HealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerHealthThresholdCount), + LoadBalancerArns: []*string{lbArn}, + Matcher: &elbv2.Matcher{}, + Port: aws.Int64(infrav1.DefaultAPIServerPort), + Protocol: aws.String("TCP"), + TargetGroupArn: tgArn, + TargetGroupName: aws.String("apiserver-target"), + UnhealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerUnhealthThresholdCount), + VpcId: aws.String("vpc-exists"), + }}, + }, nil) +} + +func mockedModifyTargetGroupAttributes(t *testing.T, m *mocks.MockELBV2APIMockRecorder) { + t.Helper() + m.ModifyTargetGroupAttributes(gomock.Eq(&elbv2.ModifyTargetGroupAttributesInput{ + TargetGroupArn: tgArn, + Attributes: []*elbv2.TargetGroupAttribute{ + { + Key: aws.String(infrav1.TargetGroupAttributeEnablePreserveClientIP), + Value: aws.String("false"), + }, + }, + })).Return(nil, nil) +} + +func mockedDescribeListenersCall(t *testing.T, m *mocks.MockELBV2APIMockRecorder) { + t.Helper() + m.DescribeListeners(gomock.Eq(&elbv2.DescribeListenersInput{ + LoadBalancerArn: lbArn, + })). + Return(&elbv2.DescribeListenersOutput{ + Listeners: []*elbv2.Listener{{ + DefaultActions: []*elbv2.Action{{ + TargetGroupArn: aws.String("arn::targetgroup-not-found"), + }}, + ListenerArn: aws.String("arn::listener"), + LoadBalancerArn: lbArn, + }}, + }, nil) +} + +func mockedCreateListenerCall(t *testing.T, m *mocks.MockELBV2APIMockRecorder) { + t.Helper() + m.CreateListener(gomock.Eq(&elbv2.CreateListenerInput{ + DefaultActions: []*elbv2.Action{ + { + TargetGroupArn: tgArn, + Type: aws.String(elbv2.ActionTypeEnumForward), + }, + }, + LoadBalancerArn: lbArn, + Port: aws.Int64(infrav1.DefaultAPIServerPort), + Protocol: aws.String("TCP"), + Tags: []*elbv2.Tag{ + { + Key: aws.String("Name"), + Value: aws.String("test-cluster-apiserver"), + }, + { + Key: aws.String("sigs.k8s.io/cluster-api-provider-aws/cluster/test-cluster"), + Value: aws.String("owned"), + }, + { + Key: aws.String("sigs.k8s.io/cluster-api-provider-aws/role"), + Value: aws.String("apiserver"), + }, + }, + })).Return(&elbv2.CreateListenerOutput{ + Listeners: []*elbv2.Listener{ + { + DefaultActions: []*elbv2.Action{ + { + TargetGroupArn: tgArn, + Type: aws.String(elbv2.ActionTypeEnumForward), + }, + }, + ListenerArn: aws.String("listener::arn"), + Port: aws.Int64(infrav1.DefaultAPIServerPort), + Protocol: aws.String("TCP"), + }, + }}, nil) +} + func mockedDeleteLBCalls(expectV2Call bool, mv2 *mocks.MockELBV2APIMockRecorder, m *mocks.MockELBAPIMockRecorder) { if expectV2Call { mv2.DescribeLoadBalancers(gomock.Any()).Return(describeLBOutputV2, nil) diff --git a/pkg/cloud/services/elb/loadbalancer.go b/pkg/cloud/services/elb/loadbalancer.go index 234123b4d5..05dc835cb8 100644 --- a/pkg/cloud/services/elb/loadbalancer.go +++ b/pkg/cloud/services/elb/loadbalancer.go @@ -87,7 +87,7 @@ func (s *Service) reconcileV2LB(lbSpec *infrav1.AWSLoadBalancerSpec) error { } // Get default api server spec. - spec, err := s.getAPIServerLBSpec(name, lbSpec) + desiredLB, err := s.getAPIServerLBSpec(name, lbSpec) if err != nil { return err } @@ -97,7 +97,7 @@ func (s *Service) reconcileV2LB(lbSpec *infrav1.AWSLoadBalancerSpec) error { // if elb is not found and owner cluster ControlPlaneEndpoint is already populated, then we should not recreate the elb. return errors.Wrapf(err, "no loadbalancer exists for the AWSCluster %s, the cluster has become unrecoverable and should be deleted manually", s.scope.InfraClusterName()) case IsNotFound(err): - lb, err = s.createLB(spec, lbSpec) + lb, err = s.createLB(desiredLB, lbSpec) if err != nil { s.scope.Error(err, "failed to create LB") return err @@ -112,36 +112,43 @@ func (s *Service) reconcileV2LB(lbSpec *infrav1.AWSLoadBalancerSpec) error { // set up the type for later processing lb.LoadBalancerType = lbSpec.LoadBalancerType if lb.IsManaged(s.scope.Name()) { - if !cmp.Equal(spec.ELBAttributes, lb.ELBAttributes) { - if err := s.configureLBAttributes(lb.ARN, spec.ELBAttributes); err != nil { + // Reconcile the target groups and listeners from the spec and the ones currently attached to the load balancer. + // Pass in the ARN that AWS gave us, as well as the rest of the desired specification. + _, _, err := s.reconcileTargetGroupsAndListeners(lb.ARN, desiredLB, lbSpec) + if err != nil { + return errors.Wrapf(err, "failed to create target groups/listeners for load balancer %q", lb.Name) + } + + if !cmp.Equal(desiredLB.ELBAttributes, lb.ELBAttributes) { + if err := s.configureLBAttributes(lb.ARN, desiredLB.ELBAttributes); err != nil { return err } } - if err := s.reconcileV2LBTags(lb, spec.Tags); err != nil { + if err := s.reconcileV2LBTags(lb, desiredLB.Tags); err != nil { return errors.Wrapf(err, "failed to reconcile tags for apiserver load balancer %q", lb.Name) } - // Reconcile the subnets and availability zones from the spec + // Reconcile the subnets and availability zones from the desiredLB // and the ones currently attached to the load balancer. - if len(lb.SubnetIDs) != len(spec.SubnetIDs) { + if len(lb.SubnetIDs) != len(desiredLB.SubnetIDs) { _, err := s.ELBV2Client.SetSubnets(&elbv2.SetSubnetsInput{ LoadBalancerArn: &lb.ARN, - Subnets: aws.StringSlice(spec.SubnetIDs), + Subnets: aws.StringSlice(desiredLB.SubnetIDs), }) if err != nil { return errors.Wrapf(err, "failed to set subnets for apiserver load balancer '%s'", lb.Name) } } - if len(lb.AvailabilityZones) != len(spec.AvailabilityZones) { - lb.AvailabilityZones = spec.AvailabilityZones + if len(lb.AvailabilityZones) != len(desiredLB.AvailabilityZones) { + lb.AvailabilityZones = desiredLB.AvailabilityZones } - // Reconcile the security groups from the spec and the ones currently attached to the load balancer - if shouldReconcileSGs(s.scope, lb, spec.SecurityGroupIDs) { + // Reconcile the security groups from the desiredLB and the ones currently attached to the load balancer + if shouldReconcileSGs(s.scope, lb, desiredLB.SecurityGroupIDs) { _, err := s.ELBV2Client.SetSecurityGroups(&elbv2.SetSecurityGroupsInput{ LoadBalancerArn: &lb.ARN, - SecurityGroups: aws.StringSlice(spec.SecurityGroupIDs), + SecurityGroups: aws.StringSlice(desiredLB.SecurityGroupIDs), }) if err != nil { return errors.Wrapf(err, "failed to apply security groups to load balancer %q", lb.Name) @@ -388,95 +395,14 @@ func (s *Service) createLB(spec *infrav1.LoadBalancer, lbSpec *infrav1.AWSLoadBa return nil, errors.New("no new network load balancer was created; the returned list is empty") } - // TODO(Skarlso): Add options to set up SSL. - // https://github.com/kubernetes-sigs/cluster-api-provider-aws/issues/3899 - for _, ln := range spec.ELBListeners { - // create the target group first - targetGroupInput := &elbv2.CreateTargetGroupInput{ - Name: aws.String(ln.TargetGroup.Name), - Port: aws.Int64(ln.TargetGroup.Port), - Protocol: aws.String(ln.TargetGroup.Protocol.String()), - VpcId: aws.String(ln.TargetGroup.VpcID), - Tags: input.Tags, - HealthCheckIntervalSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckIntervalSec), - HealthCheckTimeoutSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckTimeoutSec), - HealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerHealthThresholdCount), - UnhealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerUnhealthThresholdCount), - } - if s.scope.VPC().IsIPv6Enabled() { - targetGroupInput.IpAddressType = aws.String("ipv6") - } - if ln.TargetGroup.HealthCheck != nil { - targetGroupInput.HealthCheckEnabled = aws.Bool(true) - targetGroupInput.HealthCheckProtocol = ln.TargetGroup.HealthCheck.Protocol - targetGroupInput.HealthCheckPort = ln.TargetGroup.HealthCheck.Port - if ln.TargetGroup.HealthCheck.Path != nil { - targetGroupInput.HealthCheckPath = ln.TargetGroup.HealthCheck.Path - } - if ln.TargetGroup.HealthCheck.IntervalSeconds != nil { - targetGroupInput.HealthCheckIntervalSeconds = ln.TargetGroup.HealthCheck.IntervalSeconds - } - if ln.TargetGroup.HealthCheck.TimeoutSeconds != nil { - targetGroupInput.HealthCheckTimeoutSeconds = ln.TargetGroup.HealthCheck.TimeoutSeconds - } - if ln.TargetGroup.HealthCheck.ThresholdCount != nil { - targetGroupInput.HealthyThresholdCount = ln.TargetGroup.HealthCheck.ThresholdCount - } - if ln.TargetGroup.HealthCheck.UnhealthyThresholdCount != nil { - targetGroupInput.UnhealthyThresholdCount = ln.TargetGroup.HealthCheck.UnhealthyThresholdCount - } - } - s.scope.Debug("creating target group", "group", targetGroupInput, "listener", ln) - group, err := s.ELBV2Client.CreateTargetGroup(targetGroupInput) - if err != nil { - return nil, errors.Wrapf(err, "failed to create target group for load balancer") - } - if len(group.TargetGroups) == 0 { - return nil, errors.New("no target group was created; the returned list is empty") - } - - if !lbSpec.PreserveClientIP { - targetGroupAttributeInput := &elbv2.ModifyTargetGroupAttributesInput{ - TargetGroupArn: group.TargetGroups[0].TargetGroupArn, - Attributes: []*elbv2.TargetGroupAttribute{ - { - Key: aws.String(infrav1.TargetGroupAttributeEnablePreserveClientIP), - Value: aws.String("false"), - }, - }, - } - if _, err := s.ELBV2Client.ModifyTargetGroupAttributes(targetGroupAttributeInput); err != nil { - return nil, errors.Wrapf(err, "failed to modify target group attribute") - } - } - - listenerInput := &elbv2.CreateListenerInput{ - DefaultActions: []*elbv2.Action{ - { - TargetGroupArn: group.TargetGroups[0].TargetGroupArn, - Type: aws.String(elbv2.ActionTypeEnumForward), - }, - }, - LoadBalancerArn: out.LoadBalancers[0].LoadBalancerArn, - Port: aws.Int64(ln.Port), - Protocol: aws.String(string(ln.Protocol)), - Tags: converters.MapToV2Tags(spec.Tags), - } - // Create ClassicELBListeners - listener, err := s.ELBV2Client.CreateListener(listenerInput) - if err != nil { - return nil, errors.Wrap(err, "failed to create listener") - } - if len(listener.Listeners) == 0 { - return nil, errors.New("no listener was created; the returned list is empty") - } - } + // Target Groups and listeners will be reconciled separately s.scope.Info("Created network load balancer", "dns-name", *out.LoadBalancers[0].DNSName) res := spec.DeepCopy() s.scope.Debug("applying load balancer DNS to result", "dns", *out.LoadBalancers[0].DNSName) res.DNSName = *out.LoadBalancers[0].DNSName + res.ARN = *out.LoadBalancers[0].LoadBalancerArn return res, nil } @@ -1604,6 +1530,160 @@ func (s *Service) reconcileV2LBTags(lb *infrav1.LoadBalancer, desiredTags map[st return nil } +// reconcileTargetGroupsAndListeners reconciles a Load Balancer's defined listeners with corresponding AWS Target Groups and Listeners. +// These are combined into a single function since they are tightly integrated. +func (s *Service) reconcileTargetGroupsAndListeners(lbARN string, spec *infrav1.LoadBalancer, lbSpec *infrav1.AWSLoadBalancerSpec) ([]*elbv2.TargetGroup, []*elbv2.Listener, error) { + existingTargetGroups, err := s.ELBV2Client.DescribeTargetGroups( + &elbv2.DescribeTargetGroupsInput{ + LoadBalancerArn: aws.String(lbARN), + }) + if err != nil { + s.scope.Error(err, "could not describe target groups for load balancer", "arn", lbARN) + return nil, nil, err + } + + existingListeners, err := s.ELBV2Client.DescribeListeners( + &elbv2.DescribeListenersInput{ + LoadBalancerArn: aws.String(lbARN), + }) + if err != nil { + s.scope.Error(err, "could not describe listeners for load balancer", "arn", lbARN) + } + + createdTargetGroups := make([]*elbv2.TargetGroup, 0, len(spec.ELBListeners)) + createdListeners := make([]*elbv2.Listener, 0, len(spec.ELBListeners)) + + // TODO(Skarlso): Add options to set up SSL. + // https://github.com/kubernetes-sigs/cluster-api-provider-aws/issues/3899 + for _, ln := range spec.ELBListeners { + var group *elbv2.TargetGroup + for _, g := range existingTargetGroups.TargetGroups { + if *g.TargetGroupName == ln.TargetGroup.Name { + group = g + } + } + // create the target group first + if group == nil { + group, err = s.createTargetGroup(ln, spec.Tags) + if err != nil { + return nil, nil, err + } + createdTargetGroups = append(createdTargetGroups, group) + + if !lbSpec.PreserveClientIP { + targetGroupAttributeInput := &elbv2.ModifyTargetGroupAttributesInput{ + TargetGroupArn: group.TargetGroupArn, + Attributes: []*elbv2.TargetGroupAttribute{ + { + Key: aws.String(infrav1.TargetGroupAttributeEnablePreserveClientIP), + Value: aws.String("false"), + }, + }, + } + if _, err := s.ELBV2Client.ModifyTargetGroupAttributes(targetGroupAttributeInput); err != nil { + return nil, nil, errors.Wrapf(err, "failed to modify target group attribute") + } + } + } + + var listener *elbv2.Listener + for _, l := range existingListeners.Listeners { + if l.DefaultActions != nil && len(l.DefaultActions) > 0 && l.DefaultActions[0].TargetGroupArn == group.TargetGroupArn { + listener = l + } + } + + if listener == nil { + listener, err = s.createListener(ln, group, lbARN, spec.Tags) + if err != nil { + return nil, nil, err + } + } + + createdListeners = append(createdListeners, listener) + } + + return createdTargetGroups, createdListeners, nil +} + +// createListener creates a single Listener. +func (s *Service) createListener(ln infrav1.Listener, group *elbv2.TargetGroup, lbARN string, tags map[string]string) (*elbv2.Listener, error) { + listenerInput := &elbv2.CreateListenerInput{ + DefaultActions: []*elbv2.Action{ + { + TargetGroupArn: group.TargetGroupArn, + Type: aws.String(elbv2.ActionTypeEnumForward), + }, + }, + LoadBalancerArn: aws.String(lbARN), + Port: aws.Int64(ln.Port), + Protocol: aws.String(string(ln.Protocol)), + Tags: converters.MapToV2Tags(tags), + } + // Create ClassicELBListeners + listener, err := s.ELBV2Client.CreateListener(listenerInput) + if err != nil { + return nil, errors.Wrap(err, "failed to create listener") + } + if len(listener.Listeners) == 0 { + return nil, errors.New("no listener was created; the returned list is empty") + } + if len(listener.Listeners) > 1 { + return nil, errors.New("more than one listener created; expected only one") + } + return listener.Listeners[0], nil +} + +// createTargetGroup creates a single Target Group. +func (s *Service) createTargetGroup(ln infrav1.Listener, tags map[string]string) (*elbv2.TargetGroup, error) { + targetGroupInput := &elbv2.CreateTargetGroupInput{ + Name: aws.String(ln.TargetGroup.Name), + Port: aws.Int64(ln.TargetGroup.Port), + Protocol: aws.String(ln.TargetGroup.Protocol.String()), + VpcId: aws.String(ln.TargetGroup.VpcID), + Tags: converters.MapToV2Tags(tags), + HealthCheckIntervalSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckIntervalSec), + HealthCheckTimeoutSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckTimeoutSec), + HealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerHealthThresholdCount), + UnhealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerUnhealthThresholdCount), + } + if s.scope.VPC().IsIPv6Enabled() { + targetGroupInput.IpAddressType = aws.String("ipv6") + } + if ln.TargetGroup.HealthCheck != nil { + targetGroupInput.HealthCheckEnabled = aws.Bool(true) + targetGroupInput.HealthCheckProtocol = ln.TargetGroup.HealthCheck.Protocol + targetGroupInput.HealthCheckPort = ln.TargetGroup.HealthCheck.Port + if ln.TargetGroup.HealthCheck.Path != nil { + targetGroupInput.HealthCheckPath = ln.TargetGroup.HealthCheck.Path + } + if ln.TargetGroup.HealthCheck.IntervalSeconds != nil { + targetGroupInput.HealthCheckIntervalSeconds = ln.TargetGroup.HealthCheck.IntervalSeconds + } + if ln.TargetGroup.HealthCheck.TimeoutSeconds != nil { + targetGroupInput.HealthCheckTimeoutSeconds = ln.TargetGroup.HealthCheck.TimeoutSeconds + } + if ln.TargetGroup.HealthCheck.ThresholdCount != nil { + targetGroupInput.HealthyThresholdCount = ln.TargetGroup.HealthCheck.ThresholdCount + } + if ln.TargetGroup.HealthCheck.UnhealthyThresholdCount != nil { + targetGroupInput.UnhealthyThresholdCount = ln.TargetGroup.HealthCheck.UnhealthyThresholdCount + } + } + s.scope.Debug("creating target group", "group", targetGroupInput, "listener", ln) + group, err := s.ELBV2Client.CreateTargetGroup(targetGroupInput) + if err != nil { + return nil, errors.Wrapf(err, "failed to create target group for load balancer") + } + if len(group.TargetGroups) == 0 { + return nil, errors.New("no target group was created; the returned list is empty") + } + if len(group.TargetGroups) > 1 { + return nil, errors.New("more than one target group created; expected only one") + } + return group.TargetGroups[0], nil +} + func (s *Service) getHealthCheckTarget() string { controlPlaneELB := s.scope.ControlPlaneLoadBalancer() protocol := &infrav1.ELBProtocolSSL diff --git a/pkg/cloud/services/elb/loadbalancer_test.go b/pkg/cloud/services/elb/loadbalancer_test.go index c680a18b70..f2b4b1dbbe 100644 --- a/pkg/cloud/services/elb/loadbalancer_test.go +++ b/pkg/cloud/services/elb/loadbalancer_test.go @@ -40,6 +40,7 @@ import ( infrav1 "sigs.k8s.io/cluster-api-provider-aws/v2/api/v1beta2" "sigs.k8s.io/cluster-api-provider-aws/v2/pkg/cloud/scope" + "sigs.k8s.io/cluster-api-provider-aws/v2/test/helpers" "sigs.k8s.io/cluster-api-provider-aws/v2/test/mocks" clusterv1 "sigs.k8s.io/cluster-api/api/v1beta1" "sigs.k8s.io/cluster-api/util/conditions" @@ -780,6 +781,7 @@ func TestRegisterInstanceWithAPIServerNLB(t *testing.T) { elbName = "bar-apiserver" elbArn = "arn::apiserver" elbSubnetID = "elb-subnet" + tgArn = "arn::target-group" instanceID = "test-instance" az = "us-west-1a" differentAZ = "us-east-2c" @@ -861,14 +863,14 @@ func TestRegisterInstanceWithAPIServerNLB(t *testing.T) { LoadBalancerArns: aws.StringSlice([]string{elbArn}), Port: aws.Int64(infrav1.DefaultAPIServerPort), Protocol: aws.String("TCP"), - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), TargetGroupName: aws.String("something-generated"), VpcId: aws.String("vpc-id"), }, }, }, nil) m.RegisterTargets(gomock.Eq(&elbv2.RegisterTargetsInput{ - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), Targets: []*elbv2.TargetDescription{ { Id: aws.String(instanceID), @@ -964,7 +966,7 @@ func TestRegisterInstanceWithAPIServerNLB(t *testing.T) { LoadBalancerArns: aws.StringSlice([]string{elbArn}), Port: aws.Int64(infrav1.DefaultAPIServerPort), Protocol: aws.String("TCP"), - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), TargetGroupName: aws.String("something-generated"), VpcId: aws.String("vpc-id"), }, @@ -993,7 +995,7 @@ func TestRegisterInstanceWithAPIServerNLB(t *testing.T) { }, }, nil) m.RegisterTargets(gomock.Eq(&elbv2.RegisterTargetsInput{ - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), Targets: []*elbv2.TargetDescription{ { Id: aws.String(instanceID), @@ -1202,66 +1204,126 @@ func TestCreateNLB(t *testing.T) { }, }, }, nil) - m.CreateTargetGroup(gomock.Eq(&elbv2.CreateTargetGroupInput{ - Name: aws.String("name"), - Port: aws.Int64(infrav1.DefaultAPIServerPort), - Protocol: aws.String("TCP"), - VpcId: aws.String(vpcID), + }, + check: func(t *testing.T, lb *infrav1.LoadBalancer, err error) { + t.Helper() + if err != nil { + t.Fatalf("did not expect error: %v", err) + } + if lb.DNSName != dns { + t.Fatalf("DNSName did not equal expected value; was: '%s'", lb.DNSName) + } + }, + }, + { + name: "created with ipv6 vpc", + spec: func(spec infrav1.LoadBalancer) infrav1.LoadBalancer { + return spec + }, + awsCluster: func(acl infrav1.AWSCluster) infrav1.AWSCluster { + acl.Spec.NetworkSpec.VPC.IPv6 = &infrav1.IPv6{ + CidrBlock: "2022:1234::/64", + PoolID: "pool-id", + } + return acl + }, + elbV2APIMocks: func(m *mocks.MockELBV2APIMockRecorder) { + m.CreateLoadBalancer(gomock.Eq(&elbv2.CreateLoadBalancerInput{ + Name: aws.String(elbName), + IpAddressType: aws.String("dualstack"), + Scheme: aws.String("internet-facing"), + SecurityGroups: aws.StringSlice([]string{}), + Type: aws.String("network"), + Subnets: aws.StringSlice([]string{clusterSubnetID}), Tags: []*elbv2.Tag{ { Key: aws.String("test"), Value: aws.String("tag"), }, }, - HealthCheckEnabled: aws.Bool(true), - HealthCheckPort: aws.String(infrav1.DefaultAPIServerPortString), - HealthCheckProtocol: aws.String("tcp"), - HealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerHealthThresholdCount), - UnhealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerUnhealthThresholdCount), - HealthCheckIntervalSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckIntervalSec), - HealthCheckTimeoutSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckTimeoutSec), - })).Return(&elbv2.CreateTargetGroupOutput{ - TargetGroups: []*elbv2.TargetGroup{ + })).Return(&elbv2.CreateLoadBalancerOutput{ + LoadBalancers: []*elbv2.LoadBalancer{ { - TargetGroupArn: aws.String("target-group::arn"), - TargetGroupName: aws.String("name"), - VpcId: aws.String(vpcID), - HealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerHealthThresholdCount), - UnhealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerUnhealthThresholdCount), - HealthCheckIntervalSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckIntervalSec), - HealthCheckTimeoutSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckTimeoutSec), + LoadBalancerArn: aws.String(elbArn), + LoadBalancerName: aws.String(elbName), + Scheme: aws.String(string(infrav1.ELBSchemeInternetFacing)), + DNSName: aws.String(dns), }, }, }, nil) - m.ModifyTargetGroupAttributes(gomock.Eq(&elbv2.ModifyTargetGroupAttributesInput{ - TargetGroupArn: aws.String("target-group::arn"), - Attributes: []*elbv2.TargetGroupAttribute{ - { - Key: aws.String(infrav1.TargetGroupAttributeEnablePreserveClientIP), - Value: aws.String("false"), - }, - }, - })).Return(nil, nil) - m.CreateListener(gomock.Eq(&elbv2.CreateListenerInput{ - DefaultActions: []*elbv2.Action{ + }, + check: func(t *testing.T, lb *infrav1.LoadBalancer, err error) { + t.Helper() + if err != nil { + t.Fatalf("did not expect error: %v", err) + } + if lb.DNSName != dns { + t.Fatalf("DNSName did not equal expected value; was: '%s'", lb.DNSName) + } + }, + }, + { + name: "creating a load balancer fails", + spec: func(spec infrav1.LoadBalancer) infrav1.LoadBalancer { + return spec + }, + awsCluster: func(acl infrav1.AWSCluster) infrav1.AWSCluster { + return acl + }, + elbV2APIMocks: func(m *mocks.MockELBV2APIMockRecorder) { + m.CreateLoadBalancer(gomock.Eq(&elbv2.CreateLoadBalancerInput{ + Name: aws.String(elbName), + Scheme: aws.String("internet-facing"), + SecurityGroups: []*string{}, + Type: aws.String("network"), + Subnets: aws.StringSlice([]string{clusterSubnetID}), + Tags: []*elbv2.Tag{ { - TargetGroupArn: aws.String("target-group::arn"), - Type: aws.String(elbv2.ActionTypeEnumForward), + Key: aws.String("test"), + Value: aws.String("tag"), }, }, - LoadBalancerArn: aws.String(elbArn), - Port: aws.Int64(infrav1.DefaultAPIServerPort), - Protocol: aws.String("TCP"), + })).Return(nil, errors.New("nope")) + }, + check: func(t *testing.T, _ *infrav1.LoadBalancer, err error) { + t.Helper() + if err == nil { + t.Fatal("expected error, got nothing") + } + if !strings.Contains(err.Error(), "nope") { + t.Fatalf("expected error to contain 'nope' was instead: %s", err) + } + }, + }, + { + name: "PreserveClientIP is enabled", + spec: func(spec infrav1.LoadBalancer) infrav1.LoadBalancer { + return spec + }, + awsCluster: func(acl infrav1.AWSCluster) infrav1.AWSCluster { + acl.Spec.ControlPlaneLoadBalancer.PreserveClientIP = true + return acl + }, + elbV2APIMocks: func(m *mocks.MockELBV2APIMockRecorder) { + m.CreateLoadBalancer(gomock.Eq(&elbv2.CreateLoadBalancerInput{ + Name: aws.String(elbName), + Scheme: aws.String("internet-facing"), + SecurityGroups: aws.StringSlice([]string{}), + Type: aws.String("network"), + Subnets: aws.StringSlice([]string{clusterSubnetID}), Tags: []*elbv2.Tag{ { Key: aws.String("test"), Value: aws.String("tag"), }, }, - })).Return(&elbv2.CreateListenerOutput{ - Listeners: []*elbv2.Listener{ + })).Return(&elbv2.CreateLoadBalancerOutput{ + LoadBalancers: []*elbv2.LoadBalancer{ { - ListenerArn: aws.String("listener::arn"), + LoadBalancerArn: aws.String(elbArn), + LoadBalancerName: aws.String(elbName), + Scheme: aws.String(string(infrav1.ELBSchemeInternetFacing)), + DNSName: aws.String(dns), }, }, }, nil) @@ -1277,31 +1339,28 @@ func TestCreateNLB(t *testing.T) { }, }, { - name: "created with ipv6 vpc", + name: "load balancer is not an NLB scope security groups will be added", spec: func(spec infrav1.LoadBalancer) infrav1.LoadBalancer { + spec.SecurityGroupIDs = []string{"sg-id"} return spec }, awsCluster: func(acl infrav1.AWSCluster) infrav1.AWSCluster { - acl.Spec.NetworkSpec.VPC.IPv6 = &infrav1.IPv6{ - CidrBlock: "2022:1234::/64", - PoolID: "pool-id", - } + acl.Spec.ControlPlaneLoadBalancer.LoadBalancerType = infrav1.LoadBalancerTypeALB return acl }, elbV2APIMocks: func(m *mocks.MockELBV2APIMockRecorder) { m.CreateLoadBalancer(gomock.Eq(&elbv2.CreateLoadBalancerInput{ - Name: aws.String(elbName), - IpAddressType: aws.String("dualstack"), - Scheme: aws.String("internet-facing"), - SecurityGroups: aws.StringSlice([]string{}), - Type: aws.String("network"), - Subnets: aws.StringSlice([]string{clusterSubnetID}), + Name: aws.String(elbName), + Scheme: aws.String("internet-facing"), + Type: aws.String("application"), + Subnets: aws.StringSlice([]string{clusterSubnetID}), Tags: []*elbv2.Tag{ { Key: aws.String("test"), Value: aws.String("tag"), }, }, + SecurityGroups: aws.StringSlice([]string{"sg-id"}), })).Return(&elbv2.CreateLoadBalancerOutput{ LoadBalancers: []*elbv2.LoadBalancer{ { @@ -1312,12 +1371,138 @@ func TestCreateNLB(t *testing.T) { }, }, }, nil) + }, + check: func(t *testing.T, lb *infrav1.LoadBalancer, err error) { + t.Helper() + if err != nil { + t.Fatalf("did not expect error: %v", err) + } + if lb.DNSName != dns { + t.Fatalf("DNSName did not equal expected value; was: '%s'", lb.DNSName) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + elbV2APIMocks := mocks.NewMockELBV2API(mockCtrl) + + scheme, err := setupScheme() + if err != nil { + t.Fatal(err) + } + awsCluster := &infrav1.AWSCluster{ + ObjectMeta: metav1.ObjectMeta{Name: clusterName}, + Spec: infrav1.AWSClusterSpec{ + ControlPlaneLoadBalancer: &infrav1.AWSLoadBalancerSpec{ + Name: aws.String(elbName), + LoadBalancerType: infrav1.LoadBalancerTypeNLB, + }, + NetworkSpec: infrav1.NetworkSpec{ + VPC: infrav1.VPCSpec{ + ID: vpcID, + }, + }, + }, + } + client := fake.NewClientBuilder().WithScheme(scheme).Build() + cluster := tc.awsCluster(*awsCluster) + clusterScope, err := scope.NewClusterScope(scope.ClusterScopeParams{ + Client: client, + Cluster: &clusterv1.Cluster{ + ObjectMeta: metav1.ObjectMeta{ + Namespace: namespace, + Name: clusterName, + }, + }, + AWSCluster: &cluster, + }) + if err != nil { + t.Fatal(err) + } + + tc.elbV2APIMocks(elbV2APIMocks.EXPECT()) + + s := &Service{ + scope: clusterScope, + ELBV2Client: elbV2APIMocks, + } + + loadBalancerSpec := &infrav1.LoadBalancer{ + ARN: elbArn, + Name: elbName, + Scheme: infrav1.ELBSchemeInternetFacing, + Tags: map[string]string{ + "test": "tag", + }, + ELBListeners: []infrav1.Listener{ + { + Protocol: "TCP", + Port: infrav1.DefaultAPIServerPort, + TargetGroup: infrav1.TargetGroupSpec{ + Name: "name", + Port: infrav1.DefaultAPIServerPort, + Protocol: "TCP", + VpcID: vpcID, + HealthCheck: &infrav1.TargetGroupHealthCheck{ + Protocol: aws.String("tcp"), + Port: aws.String(infrav1.DefaultAPIServerPortString), + }, + }, + }, + }, + LoadBalancerType: infrav1.LoadBalancerTypeNLB, + SubnetIDs: []string{clusterSubnetID}, + } + + spec := tc.spec(*loadBalancerSpec) + lb, err := s.createLB(&spec, clusterScope.ControlPlaneLoadBalancer()) + tc.check(t, lb, err) + }) + } +} + +func TestReconcileTargetGroupsAndListeners(t *testing.T) { + const ( + namespace = "foo" + clusterName = "bar" + clusterSubnetID = "subnet-1" + elbName = "bar-apiserver" + elbArn = "arn::apiserver" + tgArn = "arn::target-group" + vpcID = "vpc-id" + dns = "asdf:9999/asdf" + ) + + tests := []struct { + name string + elbV2APIMocks func(m *mocks.MockELBV2APIMockRecorder) + check func(t *testing.T, tgs []*elbv2.TargetGroup, listeners []*elbv2.Listener, err error) + awsCluster func(acl infrav1.AWSCluster) infrav1.AWSCluster + spec func(spec infrav1.LoadBalancer) infrav1.LoadBalancer + }{ + { + name: "main create flow", + spec: func(spec infrav1.LoadBalancer) infrav1.LoadBalancer { + return spec + }, + awsCluster: func(acl infrav1.AWSCluster) infrav1.AWSCluster { + return acl + }, + elbV2APIMocks: func(m *mocks.MockELBV2APIMockRecorder) { + m.DescribeTargetGroups(gomock.Eq(&elbv2.DescribeTargetGroupsInput{ + LoadBalancerArn: aws.String(elbArn), + })).Return(&elbv2.DescribeTargetGroupsOutput{ + TargetGroups: []*elbv2.TargetGroup{}, + }, nil) m.CreateTargetGroup(gomock.Eq(&elbv2.CreateTargetGroupInput{ - Name: aws.String("name"), - Port: aws.Int64(infrav1.DefaultAPIServerPort), - Protocol: aws.String("TCP"), - VpcId: aws.String(vpcID), - IpAddressType: aws.String("ipv6"), + Name: aws.String("name"), + Port: aws.Int64(infrav1.DefaultAPIServerPort), + Protocol: aws.String("TCP"), + VpcId: aws.String(vpcID), Tags: []*elbv2.Tag{ { Key: aws.String("test"), @@ -1334,7 +1519,7 @@ func TestCreateNLB(t *testing.T) { })).Return(&elbv2.CreateTargetGroupOutput{ TargetGroups: []*elbv2.TargetGroup{ { - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), TargetGroupName: aws.String("name"), VpcId: aws.String(vpcID), HealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerHealthThresholdCount), @@ -1345,7 +1530,7 @@ func TestCreateNLB(t *testing.T) { }, }, nil) m.ModifyTargetGroupAttributes(gomock.Eq(&elbv2.ModifyTargetGroupAttributesInput{ - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), Attributes: []*elbv2.TargetGroupAttribute{ { Key: aws.String(infrav1.TargetGroupAttributeEnablePreserveClientIP), @@ -1353,10 +1538,15 @@ func TestCreateNLB(t *testing.T) { }, }, })).Return(nil, nil) + m.DescribeListeners(gomock.Eq(&elbv2.DescribeListenersInput{ + LoadBalancerArn: aws.String(elbArn), + })).Return(&elbv2.DescribeListenersOutput{ + Listeners: []*elbv2.Listener{}, + }, nil) m.CreateListener(gomock.Eq(&elbv2.CreateListenerInput{ DefaultActions: []*elbv2.Action{ { - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), Type: aws.String(elbv2.ActionTypeEnumForward), }, }, @@ -1372,108 +1562,72 @@ func TestCreateNLB(t *testing.T) { })).Return(&elbv2.CreateListenerOutput{ Listeners: []*elbv2.Listener{ { + DefaultActions: []*elbv2.Action{ + { + TargetGroupArn: aws.String(tgArn), + Type: aws.String(elbv2.ActionTypeEnumForward), + }, + }, ListenerArn: aws.String("listener::arn"), + Port: aws.Int64(infrav1.DefaultAPIServerPort), + Protocol: aws.String("TCP"), }, - }, - }, nil) + }}, nil) }, - check: func(t *testing.T, lb *infrav1.LoadBalancer, err error) { + check: func(t *testing.T, tgs []*elbv2.TargetGroup, listeners []*elbv2.Listener, err error) { t.Helper() if err != nil { t.Fatalf("did not expect error: %v", err) } - if lb.DNSName != dns { - t.Fatalf("DNSName did not equal expected value; was: '%s'", lb.DNSName) + if len(tgs) != 1 { + t.Fatalf("no target groups created") } - }, - }, - { - name: "creating a load balancer fails", - spec: func(spec infrav1.LoadBalancer) infrav1.LoadBalancer { - return spec - }, - awsCluster: func(acl infrav1.AWSCluster) infrav1.AWSCluster { - return acl - }, - elbV2APIMocks: func(m *mocks.MockELBV2APIMockRecorder) { - m.CreateLoadBalancer(gomock.Eq(&elbv2.CreateLoadBalancerInput{ - Name: aws.String(elbName), - Scheme: aws.String("internet-facing"), - SecurityGroups: []*string{}, - Type: aws.String("network"), - Subnets: aws.StringSlice([]string{clusterSubnetID}), - Tags: []*elbv2.Tag{ - { - Key: aws.String("test"), - Value: aws.String("tag"), - }, - }, - })).Return(nil, errors.New("nope")) - }, - check: func(t *testing.T, lb *infrav1.LoadBalancer, err error) { - t.Helper() - if err == nil { - t.Fatal("expected error, got nothing") + if len(listeners) != 1 { + t.Fatalf("no listeners created") } - if !strings.Contains(err.Error(), "nope") { - t.Fatalf("expected error to contain 'nope' was instead: %s", err) + + if len(listeners[0].DefaultActions) != 1 { + t.Fatalf("no default actions created") + } + + if *tgs[0].TargetGroupArn != *listeners[0].DefaultActions[0].TargetGroupArn { + t.Fatalf("target group and listener did not have matching arns. target group ARN: %q. listener's target group ARN: %q", *tgs[0].TargetGroupArn, *listeners[0].DefaultActions[0].TargetGroupArn) } }, }, { - name: "no health check", + name: "created with ipv6 vpc", spec: func(spec infrav1.LoadBalancer) infrav1.LoadBalancer { - spec.ELBListeners = []infrav1.Listener{ - { - Protocol: "TCP", - Port: infrav1.DefaultAPIServerPort, - TargetGroup: infrav1.TargetGroupSpec{ - Name: "name", - Port: infrav1.DefaultAPIServerPort, - Protocol: "TCP", - VpcID: vpcID, - }, - }, - } return spec }, awsCluster: func(acl infrav1.AWSCluster) infrav1.AWSCluster { + acl.Spec.NetworkSpec.VPC.IPv6 = &infrav1.IPv6{ + CidrBlock: "2022:1234::/64", + PoolID: "pool-id", + } return acl }, - elbV2APIMocks: func(m *mocks.MockELBV2APIMockRecorder) { - m.CreateLoadBalancer(gomock.Eq(&elbv2.CreateLoadBalancerInput{ - Name: aws.String(elbName), - Scheme: aws.String("internet-facing"), - SecurityGroups: aws.StringSlice([]string{}), - Type: aws.String("network"), - Subnets: aws.StringSlice([]string{clusterSubnetID}), - Tags: []*elbv2.Tag{ - { - Key: aws.String("test"), - Value: aws.String("tag"), - }, - }, - })).Return(&elbv2.CreateLoadBalancerOutput{ - LoadBalancers: []*elbv2.LoadBalancer{ - { - LoadBalancerArn: aws.String(elbArn), - LoadBalancerName: aws.String(elbName), - Scheme: aws.String(string(infrav1.ELBSchemeInternetFacing)), - DNSName: aws.String(dns), - }, - }, + elbV2APIMocks: func(m *mocks.MockELBV2APIMockRecorder) { + m.DescribeTargetGroups(gomock.Eq(&elbv2.DescribeTargetGroupsInput{ + LoadBalancerArn: aws.String(elbArn), + })).Return(&elbv2.DescribeTargetGroupsOutput{ + TargetGroups: []*elbv2.TargetGroup{}, }, nil) m.CreateTargetGroup(gomock.Eq(&elbv2.CreateTargetGroupInput{ - Name: aws.String("name"), - Port: aws.Int64(infrav1.DefaultAPIServerPort), - Protocol: aws.String("TCP"), - VpcId: aws.String(vpcID), + Name: aws.String("name"), + Port: aws.Int64(infrav1.DefaultAPIServerPort), + Protocol: aws.String("TCP"), + VpcId: aws.String(vpcID), + IpAddressType: aws.String("ipv6"), Tags: []*elbv2.Tag{ { Key: aws.String("test"), Value: aws.String("tag"), }, }, + HealthCheckEnabled: aws.Bool(true), + HealthCheckPort: aws.String(infrav1.DefaultAPIServerPortString), + HealthCheckProtocol: aws.String("tcp"), HealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerHealthThresholdCount), UnhealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerUnhealthThresholdCount), HealthCheckIntervalSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckIntervalSec), @@ -1481,18 +1635,19 @@ func TestCreateNLB(t *testing.T) { })).Return(&elbv2.CreateTargetGroupOutput{ TargetGroups: []*elbv2.TargetGroup{ { - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), TargetGroupName: aws.String("name"), VpcId: aws.String(vpcID), HealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerHealthThresholdCount), UnhealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerUnhealthThresholdCount), HealthCheckIntervalSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckIntervalSec), HealthCheckTimeoutSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckTimeoutSec), + IpAddressType: aws.String("ipv6"), }, }, }, nil) m.ModifyTargetGroupAttributes(gomock.Eq(&elbv2.ModifyTargetGroupAttributesInput{ - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), Attributes: []*elbv2.TargetGroupAttribute{ { Key: aws.String(infrav1.TargetGroupAttributeEnablePreserveClientIP), @@ -1500,10 +1655,15 @@ func TestCreateNLB(t *testing.T) { }, }, })).Return(nil, nil) + m.DescribeListeners(gomock.Eq(&elbv2.DescribeListenersInput{ + LoadBalancerArn: aws.String(elbArn), + })).Return(&elbv2.DescribeListenersOutput{ + Listeners: []*elbv2.Listener{}, + }, nil) m.CreateListener(gomock.Eq(&elbv2.CreateListenerInput{ DefaultActions: []*elbv2.Action{ { - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), Type: aws.String(elbv2.ActionTypeEnumForward), }, }, @@ -1524,56 +1684,50 @@ func TestCreateNLB(t *testing.T) { }, }, nil) }, - check: func(t *testing.T, lb *infrav1.LoadBalancer, err error) { + check: func(t *testing.T, tgs []*elbv2.TargetGroup, _ []*elbv2.Listener, err error) { t.Helper() if err != nil { t.Fatalf("did not expect error: %v", err) } - if lb.DNSName != dns { - t.Fatalf("DNSName did not equal expected value; was: '%s'", lb.DNSName) + tg := tgs[0] + got := *tg.IpAddressType + want := "ipv6" + if got != want { + t.Fatalf("did not set ip address type to ipv6") } }, }, { - name: "PreserveClientIP is enabled", + name: "no health check", spec: func(spec infrav1.LoadBalancer) infrav1.LoadBalancer { + spec.ELBListeners = []infrav1.Listener{ + { + Protocol: "TCP", + Port: infrav1.DefaultAPIServerPort, + TargetGroup: infrav1.TargetGroupSpec{ + Name: "name", + Port: infrav1.DefaultAPIServerPort, + Protocol: "TCP", + VpcID: vpcID, + }, + }, + } return spec }, awsCluster: func(acl infrav1.AWSCluster) infrav1.AWSCluster { - acl.Spec.ControlPlaneLoadBalancer.PreserveClientIP = true return acl }, elbV2APIMocks: func(m *mocks.MockELBV2APIMockRecorder) { - m.CreateLoadBalancer(gomock.Eq(&elbv2.CreateLoadBalancerInput{ - Name: aws.String(elbName), - Scheme: aws.String("internet-facing"), - SecurityGroups: aws.StringSlice([]string{}), - Type: aws.String("network"), - Subnets: aws.StringSlice([]string{clusterSubnetID}), - Tags: []*elbv2.Tag{ - { - Key: aws.String("test"), - Value: aws.String("tag"), - }, - }, - })).Return(&elbv2.CreateLoadBalancerOutput{ - LoadBalancers: []*elbv2.LoadBalancer{ - { - LoadBalancerArn: aws.String(elbArn), - LoadBalancerName: aws.String(elbName), - Scheme: aws.String(string(infrav1.ELBSchemeInternetFacing)), - DNSName: aws.String(dns), - }, - }, + m.DescribeTargetGroups(gomock.Eq(&elbv2.DescribeTargetGroupsInput{ + LoadBalancerArn: aws.String(elbArn), + })).Return(&elbv2.DescribeTargetGroupsOutput{ + TargetGroups: []*elbv2.TargetGroup{}, }, nil) m.CreateTargetGroup(gomock.Eq(&elbv2.CreateTargetGroupInput{ - HealthCheckEnabled: aws.Bool(true), - HealthCheckPort: aws.String(infrav1.DefaultAPIServerPortString), - HealthCheckProtocol: aws.String("tcp"), - Name: aws.String("name"), - Port: aws.Int64(infrav1.DefaultAPIServerPort), - Protocol: aws.String("TCP"), - VpcId: aws.String(vpcID), + Name: aws.String("name"), + Port: aws.Int64(infrav1.DefaultAPIServerPort), + Protocol: aws.String("TCP"), + VpcId: aws.String(vpcID), Tags: []*elbv2.Tag{ { Key: aws.String("test"), @@ -1587,20 +1741,35 @@ func TestCreateNLB(t *testing.T) { })).Return(&elbv2.CreateTargetGroupOutput{ TargetGroups: []*elbv2.TargetGroup{ { - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), TargetGroupName: aws.String("name"), VpcId: aws.String(vpcID), HealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerHealthThresholdCount), UnhealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerUnhealthThresholdCount), HealthCheckIntervalSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckIntervalSec), HealthCheckTimeoutSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckTimeoutSec), + HealthCheckEnabled: aws.Bool(false), + }, + }, + }, nil) + m.ModifyTargetGroupAttributes(gomock.Eq(&elbv2.ModifyTargetGroupAttributesInput{ + TargetGroupArn: aws.String(tgArn), + Attributes: []*elbv2.TargetGroupAttribute{ + { + Key: aws.String(infrav1.TargetGroupAttributeEnablePreserveClientIP), + Value: aws.String("false"), }, }, + })).Return(nil, nil) + m.DescribeListeners(gomock.Eq(&elbv2.DescribeListenersInput{ + LoadBalancerArn: aws.String(elbArn), + })).Return(&elbv2.DescribeListenersOutput{ + Listeners: []*elbv2.Listener{}, }, nil) m.CreateListener(gomock.Eq(&elbv2.CreateListenerInput{ DefaultActions: []*elbv2.Action{ { - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), Type: aws.String(elbv2.ActionTypeEnumForward), }, }, @@ -1621,48 +1790,32 @@ func TestCreateNLB(t *testing.T) { }, }, nil) }, - check: func(t *testing.T, lb *infrav1.LoadBalancer, err error) { + check: func(t *testing.T, tgs []*elbv2.TargetGroup, _ []*elbv2.Listener, err error) { t.Helper() if err != nil { t.Fatalf("did not expect error: %v", err) } - if lb.DNSName != dns { - t.Fatalf("DNSName did not equal expected value; was: '%s'", lb.DNSName) + got := *tgs[0].HealthCheckEnabled + want := false + if got != want { + t.Fatalf("health check not disabled on target group") } }, }, { - name: "load balancer is not an NLB scope security groups will be added", + name: "PreserveClientIP is enabled", spec: func(spec infrav1.LoadBalancer) infrav1.LoadBalancer { - spec.SecurityGroupIDs = []string{"sg-id"} return spec }, awsCluster: func(acl infrav1.AWSCluster) infrav1.AWSCluster { - acl.Spec.ControlPlaneLoadBalancer.LoadBalancerType = infrav1.LoadBalancerTypeALB + acl.Spec.ControlPlaneLoadBalancer.PreserveClientIP = true return acl }, elbV2APIMocks: func(m *mocks.MockELBV2APIMockRecorder) { - m.CreateLoadBalancer(gomock.Eq(&elbv2.CreateLoadBalancerInput{ - Name: aws.String(elbName), - Scheme: aws.String("internet-facing"), - Type: aws.String("application"), - Subnets: aws.StringSlice([]string{clusterSubnetID}), - Tags: []*elbv2.Tag{ - { - Key: aws.String("test"), - Value: aws.String("tag"), - }, - }, - SecurityGroups: aws.StringSlice([]string{"sg-id"}), - })).Return(&elbv2.CreateLoadBalancerOutput{ - LoadBalancers: []*elbv2.LoadBalancer{ - { - LoadBalancerArn: aws.String(elbArn), - LoadBalancerName: aws.String(elbName), - Scheme: aws.String(string(infrav1.ELBSchemeInternetFacing)), - DNSName: aws.String(dns), - }, - }, + m.DescribeTargetGroups(gomock.Eq(&elbv2.DescribeTargetGroupsInput{ + LoadBalancerArn: aws.String(elbArn), + })).Return(&elbv2.DescribeTargetGroupsOutput{ + TargetGroups: []*elbv2.TargetGroup{}, }, nil) m.CreateTargetGroup(gomock.Eq(&elbv2.CreateTargetGroupInput{ HealthCheckEnabled: aws.Bool(true), @@ -1685,7 +1838,7 @@ func TestCreateNLB(t *testing.T) { })).Return(&elbv2.CreateTargetGroupOutput{ TargetGroups: []*elbv2.TargetGroup{ { - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), TargetGroupName: aws.String("name"), VpcId: aws.String(vpcID), HealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerHealthThresholdCount), @@ -1695,19 +1848,15 @@ func TestCreateNLB(t *testing.T) { }, }, }, nil) - m.ModifyTargetGroupAttributes(gomock.Eq(&elbv2.ModifyTargetGroupAttributesInput{ - TargetGroupArn: aws.String("target-group::arn"), - Attributes: []*elbv2.TargetGroupAttribute{ - { - Key: aws.String(infrav1.TargetGroupAttributeEnablePreserveClientIP), - Value: aws.String("false"), - }, - }, - })).Return(nil, nil) + m.DescribeListeners(gomock.Eq(&elbv2.DescribeListenersInput{ + LoadBalancerArn: aws.String(elbArn), + })).Return(&elbv2.DescribeListenersOutput{ + Listeners: []*elbv2.Listener{}, + }, nil) m.CreateListener(gomock.Eq(&elbv2.CreateListenerInput{ DefaultActions: []*elbv2.Action{ { - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), Type: aws.String(elbv2.ActionTypeEnumForward), }, }, @@ -1728,13 +1877,18 @@ func TestCreateNLB(t *testing.T) { }, }, nil) }, - check: func(t *testing.T, lb *infrav1.LoadBalancer, err error) { + check: func(t *testing.T, tgs []*elbv2.TargetGroup, listeners []*elbv2.Listener, err error) { t.Helper() if err != nil { t.Fatalf("did not expect error: %v", err) } - if lb.DNSName != dns { - t.Fatalf("DNSName did not equal expected value; was: '%s'", lb.DNSName) + + if len(tgs) != 1 { + t.Fatalf("did not create target groups") + } + + if len(listeners) != 1 { + t.Fatalf("did not create any listeners") } }, }, @@ -1762,27 +1916,10 @@ func TestCreateNLB(t *testing.T) { return spec }, elbV2APIMocks: func(m *mocks.MockELBV2APIMockRecorder) { - m.CreateLoadBalancer(gomock.Eq(&elbv2.CreateLoadBalancerInput{ - Name: aws.String(elbName), - Scheme: aws.String("internet-facing"), - SecurityGroups: aws.StringSlice([]string{}), - Type: aws.String("network"), - Subnets: aws.StringSlice([]string{clusterSubnetID}), - Tags: []*elbv2.Tag{ - { - Key: aws.String("test"), - Value: aws.String("tag"), - }, - }, - })).Return(&elbv2.CreateLoadBalancerOutput{ - LoadBalancers: []*elbv2.LoadBalancer{ - { - LoadBalancerArn: aws.String(elbArn), - LoadBalancerName: aws.String(elbName), - Scheme: aws.String(string(infrav1.ELBSchemeInternetFacing)), - DNSName: aws.String(dns), - }, - }, + m.DescribeTargetGroups(gomock.Eq(&elbv2.DescribeTargetGroupsInput{ + LoadBalancerArn: aws.String(elbArn), + })).Return(&elbv2.DescribeTargetGroupsOutput{ + TargetGroups: []*elbv2.TargetGroup{}, }, nil) m.CreateTargetGroup(gomock.Eq(&elbv2.CreateTargetGroupInput{ Name: aws.String("name"), @@ -1806,12 +1943,12 @@ func TestCreateNLB(t *testing.T) { })).Return(&elbv2.CreateTargetGroupOutput{ TargetGroups: []*elbv2.TargetGroup{ { - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), TargetGroupName: aws.String("name"), VpcId: aws.String(vpcID), HealthCheckEnabled: aws.Bool(true), HealthCheckPort: aws.String(infrav1.DefaultAPIServerPortString), - HealthCheckProtocol: aws.String("http"), + HealthCheckProtocol: aws.String("HTTP"), HealthCheckPath: aws.String("/readyz"), HealthCheckIntervalSeconds: aws.Int64(10), HealthCheckTimeoutSeconds: aws.Int64(5), @@ -1820,10 +1957,15 @@ func TestCreateNLB(t *testing.T) { }, }, }, nil) + m.DescribeListeners(gomock.Eq(&elbv2.DescribeListenersInput{ + LoadBalancerArn: aws.String(elbArn), + })).Return(&elbv2.DescribeListenersOutput{ + Listeners: []*elbv2.Listener{}, + }, nil) m.CreateListener(gomock.Eq(&elbv2.CreateListenerInput{ DefaultActions: []*elbv2.Action{ { - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), Type: aws.String(elbv2.ActionTypeEnumForward), }, }, @@ -1844,7 +1986,7 @@ func TestCreateNLB(t *testing.T) { }, }, nil) m.ModifyTargetGroupAttributes(gomock.Eq(&elbv2.ModifyTargetGroupAttributesInput{ - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), Attributes: []*elbv2.TargetGroupAttribute{ { Key: aws.String(infrav1.TargetGroupAttributeEnablePreserveClientIP), @@ -1853,12 +1995,12 @@ func TestCreateNLB(t *testing.T) { }, })).Return(nil, nil) }, - check: func(t *testing.T, lb *infrav1.LoadBalancer, err error) { + check: func(t *testing.T, tgs []*elbv2.TargetGroup, _ []*elbv2.Listener, err error) { t.Helper() if err != nil { t.Fatalf("did not expect error: %v", err) } - got := *lb.ELBListeners[0].TargetGroup.HealthCheck.Protocol + got := *tgs[0].HealthCheckProtocol want := "HTTP" if got != want { t.Fatalf("Health Check protocol for the API Target group did not equal expected value: %s; was: '%s'", want, got) @@ -1889,27 +2031,10 @@ func TestCreateNLB(t *testing.T) { return spec }, elbV2APIMocks: func(m *mocks.MockELBV2APIMockRecorder) { - m.CreateLoadBalancer(gomock.Eq(&elbv2.CreateLoadBalancerInput{ - Name: aws.String(elbName), - Scheme: aws.String("internet-facing"), - SecurityGroups: aws.StringSlice([]string{}), - Type: aws.String("network"), - Subnets: aws.StringSlice([]string{clusterSubnetID}), - Tags: []*elbv2.Tag{ - { - Key: aws.String("test"), - Value: aws.String("tag"), - }, - }, - })).Return(&elbv2.CreateLoadBalancerOutput{ - LoadBalancers: []*elbv2.LoadBalancer{ - { - LoadBalancerArn: aws.String(elbArn), - LoadBalancerName: aws.String(elbName), - Scheme: aws.String(string(infrav1.ELBSchemeInternetFacing)), - DNSName: aws.String(dns), - }, - }, + m.DescribeTargetGroups(gomock.Eq(&elbv2.DescribeTargetGroupsInput{ + LoadBalancerArn: aws.String(elbArn), + })).Return(&elbv2.DescribeTargetGroupsOutput{ + TargetGroups: []*elbv2.TargetGroup{}, }, nil) m.CreateTargetGroup(gomock.Eq(&elbv2.CreateTargetGroupInput{ Name: aws.String("name"), @@ -1933,7 +2058,7 @@ func TestCreateNLB(t *testing.T) { })).Return(&elbv2.CreateTargetGroupOutput{ TargetGroups: []*elbv2.TargetGroup{ { - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), TargetGroupName: aws.String("name"), VpcId: aws.String(vpcID), HealthCheckEnabled: aws.Bool(true), @@ -1947,10 +2072,15 @@ func TestCreateNLB(t *testing.T) { }, }, }, nil) + m.DescribeListeners(gomock.Eq(&elbv2.DescribeListenersInput{ + LoadBalancerArn: aws.String(elbArn), + })).Return(&elbv2.DescribeListenersOutput{ + Listeners: []*elbv2.Listener{}, + }, nil) m.CreateListener(gomock.Eq(&elbv2.CreateListenerInput{ DefaultActions: []*elbv2.Action{ { - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), Type: aws.String(elbv2.ActionTypeEnumForward), }, }, @@ -1971,7 +2101,7 @@ func TestCreateNLB(t *testing.T) { }, }, nil) m.ModifyTargetGroupAttributes(gomock.Eq(&elbv2.ModifyTargetGroupAttributesInput{ - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), Attributes: []*elbv2.TargetGroupAttribute{ { Key: aws.String(infrav1.TargetGroupAttributeEnablePreserveClientIP), @@ -1980,12 +2110,12 @@ func TestCreateNLB(t *testing.T) { }, })).Return(nil, nil) }, - check: func(t *testing.T, lb *infrav1.LoadBalancer, err error) { + check: func(t *testing.T, tgs []*elbv2.TargetGroup, _ []*elbv2.Listener, err error) { t.Helper() if err != nil { t.Fatalf("did not expect error: %v", err) } - got := *lb.ELBListeners[0].TargetGroup.HealthCheck.Protocol + got := *tgs[0].HealthCheckProtocol want := "HTTPS" if got != want { t.Fatalf("Health Check protocol for the API Target group did not equal expected value: %s; was: '%s'", want, got) @@ -2069,8 +2199,8 @@ func TestCreateNLB(t *testing.T) { } spec := tc.spec(*loadBalancerSpec) - lb, err := s.createLB(&spec, clusterScope.ControlPlaneLoadBalancer()) - tc.check(t, lb, err) + tgs, listeners, err := s.reconcileTargetGroupsAndListeners(spec.ARN, &spec, clusterScope.ControlPlaneLoadBalancer()) + tc.check(t, tgs, listeners, err) }) } } @@ -2082,6 +2212,7 @@ func TestReconcileV2LB(t *testing.T) { clusterSubnetID = "subnet-1" elbName = "bar-apiserver" elbArn = "arn::apiserver" + tgArn = "arn::target-group" vpcID = "vpc-id" az = "us-west-1a" ) @@ -2186,6 +2317,20 @@ func TestReconcileV2LB(t *testing.T) { }, }, }, nil) + m.DescribeTargetGroups(gomock.Eq(&elbv2.DescribeTargetGroupsInput{ + LoadBalancerArn: aws.String(elbArn), + })). + Return(&elbv2.DescribeTargetGroupsOutput{ + NextMarker: new(string), + TargetGroups: []*elbv2.TargetGroup{ + { + HealthCheckEnabled: aws.Bool(true), + LoadBalancerArns: []*string{aws.String(elbArn)}, + Matcher: &elbv2.Matcher{}, + TargetGroupArn: aws.String(tgArn), + TargetGroupName: aws.String("targetGroup"), + }}, + }, nil) m.ModifyLoadBalancerAttributes(&elbv2.ModifyLoadBalancerAttributesInput{ LoadBalancerArn: aws.String(elbArn), Attributes: []*elbv2.LoadBalancerAttribute{ @@ -2195,6 +2340,107 @@ func TestReconcileV2LB(t *testing.T) { }, }}). Return(&elbv2.ModifyLoadBalancerAttributesOutput{}, nil) + + m.CreateTargetGroup(helpers.PartialMatchCreateTargetGroupInput(t, &elbv2.CreateTargetGroupInput{ + HealthCheckEnabled: aws.Bool(true), + HealthCheckIntervalSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckIntervalSec), + HealthCheckPort: aws.String(infrav1.DefaultAPIServerPortString), + HealthCheckProtocol: aws.String("TCP"), + HealthCheckTimeoutSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckTimeoutSec), + HealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerHealthThresholdCount), + // Note: this is treated as a prefix with the partial matcher. + Name: aws.String("apiserver-target"), + Port: aws.Int64(infrav1.DefaultAPIServerPort), + Protocol: aws.String("TCP"), + Tags: []*elbv2.Tag{ + { + Key: aws.String("Name"), + Value: aws.String("bar-apiserver"), + }, + { + Key: aws.String("sigs.k8s.io/cluster-api-provider-aws/cluster/bar"), + Value: aws.String("owned"), + }, + { + Key: aws.String("sigs.k8s.io/cluster-api-provider-aws/role"), + Value: aws.String("apiserver"), + }, + }, + UnhealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerUnhealthThresholdCount), + VpcId: aws.String(vpcID), + })).Return(&elbv2.CreateTargetGroupOutput{ + TargetGroups: []*elbv2.TargetGroup{ + { + TargetGroupArn: aws.String(tgArn), + VpcId: aws.String(vpcID), + HealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerHealthThresholdCount), + UnhealthyThresholdCount: aws.Int64(infrav1.DefaultAPIServerUnhealthThresholdCount), + HealthCheckIntervalSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckIntervalSec), + HealthCheckTimeoutSeconds: aws.Int64(infrav1.DefaultAPIServerHealthCheckTimeoutSec), + }, + }, + }, nil) + + m.ModifyTargetGroupAttributes(gomock.Eq(&elbv2.ModifyTargetGroupAttributesInput{ + TargetGroupArn: aws.String(tgArn), + Attributes: []*elbv2.TargetGroupAttribute{ + { + Key: aws.String(infrav1.TargetGroupAttributeEnablePreserveClientIP), + Value: aws.String("false"), + }, + }, + })).Return(nil, nil) + + m.DescribeListeners(gomock.Eq(&elbv2.DescribeListenersInput{ + LoadBalancerArn: aws.String(elbArn), + })). + Return(&elbv2.DescribeListenersOutput{ + Listeners: []*elbv2.Listener{{ + DefaultActions: []*elbv2.Action{{ + TargetGroupArn: aws.String("arn::targetgroup"), + }}, + ListenerArn: aws.String("arn::listener"), + LoadBalancerArn: aws.String(elbArn), + }}, + }, nil) + m.CreateListener(gomock.Eq(&elbv2.CreateListenerInput{ + DefaultActions: []*elbv2.Action{ + { + TargetGroupArn: aws.String(tgArn), + Type: aws.String(elbv2.ActionTypeEnumForward), + }, + }, + LoadBalancerArn: aws.String(elbArn), + Port: aws.Int64(infrav1.DefaultAPIServerPort), + Protocol: aws.String("TCP"), + Tags: []*elbv2.Tag{ + { + Key: aws.String("Name"), + Value: aws.String("bar-apiserver"), + }, + { + Key: aws.String("sigs.k8s.io/cluster-api-provider-aws/cluster/bar"), + Value: aws.String("owned"), + }, + { + Key: aws.String("sigs.k8s.io/cluster-api-provider-aws/role"), + Value: aws.String("apiserver"), + }, + }, + })).Return(&elbv2.CreateListenerOutput{ + Listeners: []*elbv2.Listener{ + { + DefaultActions: []*elbv2.Action{ + { + TargetGroupArn: aws.String(tgArn), + Type: aws.String(elbv2.ActionTypeEnumForward), + }, + }, + ListenerArn: aws.String("listener::arn"), + Port: aws.Int64(infrav1.DefaultAPIServerPort), + Protocol: aws.String("TCP"), + }, + }}, nil) m.DescribeLoadBalancerAttributes(&elbv2.DescribeLoadBalancerAttributesInput{LoadBalancerArn: aws.String(elbArn)}).Return( &elbv2.DescribeLoadBalancerAttributesOutput{ Attributes: []*elbv2.LoadBalancerAttribute{ @@ -2207,9 +2453,7 @@ func TestReconcileV2LB(t *testing.T) { Value: aws.String(string(infrav1.ResourceLifecycleOwned)), }, }, - }, - nil, - ) + }, nil) m.DescribeTags(&elbv2.DescribeTagsInput{ResourceArns: []*string{aws.String(elbArn)}}).Return( &elbv2.DescribeTagsOutput{ TagDescriptions: []*elbv2.TagDescription{ @@ -2223,9 +2467,7 @@ func TestReconcileV2LB(t *testing.T) { }, }, }, - }, - nil, - ) + }, nil) // Avoid the need to sort the AddTagsInput.Tags slice m.AddTags(gomock.AssignableToTypeOf(&elbv2.AddTagsInput{})).Return(&elbv2.AddTagsOutput{}, nil) @@ -2685,6 +2927,7 @@ func TestDeleteNLB(t *testing.T) { clusterName := "bar" elbName := "bar-apiserver" elbArn := "apiserver::arn" + tgArn := "arn::target-group" tests := []struct { name string elbv2ApiMock func(m *mocks.MockELBV2APIMockRecorder) @@ -2794,11 +3037,11 @@ func TestDeleteNLB(t *testing.T) { m.DescribeTargetGroups(&elbv2.DescribeTargetGroupsInput{LoadBalancerArn: aws.String(elbArn)}).Return(&elbv2.DescribeTargetGroupsOutput{ TargetGroups: []*elbv2.TargetGroup{ { - TargetGroupArn: aws.String("target-group::arn"), + TargetGroupArn: aws.String(tgArn), }, }, }, nil) - m.DeleteTargetGroup(&elbv2.DeleteTargetGroupInput{TargetGroupArn: aws.String("target-group::arn")}).Return(&elbv2.DeleteTargetGroupOutput{}, nil) + m.DeleteTargetGroup(&elbv2.DeleteTargetGroupInput{TargetGroupArn: aws.String(tgArn)}).Return(&elbv2.DeleteTargetGroupOutput{}, nil) // delete the load balancer m.DeleteLoadBalancer(&elbv2.DeleteLoadBalancerInput{LoadBalancerArn: aws.String(elbArn)}).Return( diff --git a/test/helpers/matchers.go b/test/helpers/matchers.go new file mode 100644 index 0000000000..202ae22c27 --- /dev/null +++ b/test/helpers/matchers.go @@ -0,0 +1,66 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package helpers + +import ( + "fmt" + "strings" + "testing" + + "github.com/aws/aws-sdk-go/service/elbv2" + "github.com/golang/mock/gomock" +) + +// PartialMatchCreateTargetGroupInput matches a partial CreateTargetGroupInput struct based on fuzzy matching rules. +func PartialMatchCreateTargetGroupInput(t *testing.T, i *elbv2.CreateTargetGroupInput) gomock.Matcher { + t.Helper() + return &createTargetGroupInputPartialMatcher{ + in: i, + t: t, + } +} + +// createTargetGroupInputPartialMatcher conforms to the gomock.Matcher interface in order to implement a match against a partial +// CreateTargetGroupInput expected value. +// In particular, the TargetGroupName expected value is used as a prefix, in order to support generated names. +type createTargetGroupInputPartialMatcher struct { + in *elbv2.CreateTargetGroupInput + t *testing.T +} + +func (m *createTargetGroupInputPartialMatcher) Matches(x interface{}) bool { + actual, ok := x.(*elbv2.CreateTargetGroupInput) + if !ok { + return false + } + + // Check for a perfect match across all fields first. + eq := gomock.Eq(m.in).Matches(actual) + + if !eq && (actual.Name != nil && m.in.Name != nil) { + // If the actual name is prefixed with the expected value, then it matches + if (*actual.Name != *m.in.Name) && strings.HasPrefix(*actual.Name, *m.in.Name) { + return true + } + } + + return eq +} + +func (m *createTargetGroupInputPartialMatcher) String() string { + return fmt.Sprintf("%v (%T)", m.in, m.in) +}