diff --git a/pkg/provider/azure_loadbalancer_accesscontrol_test.go b/pkg/provider/azure_loadbalancer_accesscontrol_test.go index cc63267c23..0112c7fd84 100644 --- a/pkg/provider/azure_loadbalancer_accesscontrol_test.go +++ b/pkg/provider/azure_loadbalancer_accesscontrol_test.go @@ -2242,6 +2242,11 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { WithDestination(append(azureFx.LoadBalancer().IPv6Addresses(), "baz")...). // Should keep baz but clean the rest Build(), + azureFx.DenyAllSecurityRule(iputil.IPv4). + WithPriority(4095). + WithDestination(append(azureFx.LoadBalancer().IPv4Addresses(), "5.5.5.5/32")...). + Build(), + { Name: ptr.To("foo"), SecurityRulePropertiesFormat: &network.SecurityRulePropertiesFormat{ @@ -2338,6 +2343,11 @@ func TestCloud_reconcileSecurityGroup(t *testing.T) { WithPriority(3000). WithDestination("foo"). // should keep foo Build(), + + azureFx.DenyAllSecurityRule(iputil.IPv4). + WithPriority(4095). + WithDestination("5.5.5.5/32"). + Build(), ) testutil.ExpectExactSecurityRules(t, &properties, rules) diff --git a/pkg/provider/loadbalancer/securitygroup/securitygroup.go b/pkg/provider/loadbalancer/securitygroup/securitygroup.go index 2144b0f31c..01227960ff 100644 --- a/pkg/provider/loadbalancer/securitygroup/securitygroup.go +++ b/pkg/provider/loadbalancer/securitygroup/securitygroup.go @@ -327,6 +327,24 @@ func (helper *RuleHelper) RemoveDestinationFromRules( func (helper *RuleHelper) removeDestinationFromRule(rule *network.SecurityRule, prefixes []string, retainDstPorts []int32) error { logger := helper.logger.WithName("removeDestinationFromRule"). WithValues("security-rule-name", rule.Name) + + var ( + prefixIndex = fnutil.IndexSet(prefixes) // Used to check whether the prefix should be removed. + currentPrefixes = ListDestinationPrefixes(rule) + + expectedPrefixes = fnutil.RemoveIf(func(p string) bool { return prefixIndex[p] }, currentPrefixes) // The prefixes to keep. + targetPrefixes = fnutil.Intersection(currentPrefixes, prefixes) // The prefixes to remove. + ) + + // Clean DenyAll rule + if rule.Access == network.SecurityRuleAccessDeny && len(retainDstPorts) == 0 { + // Update the prefixes + rule.DestinationAddressPrefix = nil + rule.DestinationAddressPrefixes = ptr.To(NormalizeSecurityRuleAddressPrefixes(expectedPrefixes)) + return nil + } + + // Clean Allow rule currentPorts, err := ListDestinationPortRanges(rule) if err != nil { // Skip the rule with invalid destination port ranges. @@ -334,14 +352,8 @@ func (helper *RuleHelper) removeDestinationFromRule(rule *network.SecurityRule, logger.Info("Skip because it contains `*` or port-ranges as destination port ranges.") return nil } - var ( - prefixIndex = fnutil.IndexSet(prefixes) // Used to check whether the prefix should be removed. - currentPrefixes = ListDestinationPrefixes(rule) - - expectedPrefixes = fnutil.RemoveIf(func(p string) bool { return prefixIndex[p] }, currentPrefixes) // The prefixes to keep. - targetPrefixes = fnutil.Intersection(currentPrefixes, prefixes) // The prefixes to remove. - expectedPorts = fnutil.Intersection(currentPorts, retainDstPorts) // The ports to keep. + expectedPorts = fnutil.Intersection(currentPorts, retainDstPorts) // The ports to keep. ) if len(targetPrefixes) == 0 || len(currentPorts) == len(expectedPorts) { diff --git a/tests/e2e/network/network_security_group.go b/tests/e2e/network/network_security_group.go index 22fc16c1c6..95d18a0fbb 100644 --- a/tests/e2e/network/network_security_group.go +++ b/tests/e2e/network/network_security_group.go @@ -759,32 +759,34 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( }) By("Creating service 2", func() { + + joinIPsAsString := func(ips []netip.Addr) string { + var s []string + for _, ip := range ips { + s = append(s, ip.String()) + } + return strings.Join(s, ",") + } + var ( labels = map[string]string{ "app": Deployment2Name, } - annotations = map[string]string{} - ports = []v1.ServicePort{{ + annotations = map[string]string{ + "service.beta.kubernetes.io/azure-load-balancer-ipv4": joinIPsAsString(svc1IPv4s), + "service.beta.kubernetes.io/azure-load-balancer-ipv6": joinIPsAsString(svc1IPv6s), + } + ports = []v1.ServicePort{{ Port: app2Port, TargetPort: intstr.FromInt32(app2Port), }} ) - var ip netip.Addr - if len(svc1IPv4s) > 0 { - ip = svc1IPv4s[0] - } - if len(svc1IPv6s) > 0 { - ip = svc1IPv6s[0] - } - rv := createAndExposeDefaultServiceWithAnnotation(k8sClient, azureClient.IPFamily, Service2Name, namespace.Name, labels, annotations, ports, func(svc *v1.Service) error { - svc.Spec.LoadBalancerIP = ip.String() - return nil - }) - svc2IPs = mustParseIPs(derefSliceOfStringPtr(rv)) - logger.Info("Created the second LoadBalancer service", "svc-name", Service2Name, "IPs", svc2IPs) - Expect(svc2IPs).To(HaveLen(1)) - Expect(svc2IPs[0]).To(Equal(ip)) + rv := createAndExposeDefaultServiceWithAnnotation(k8sClient, azureClient.IPFamily, Service2Name, namespace.Name, labels, annotations, ports) + svc2IPv4s, svc2IPv6s := groupIPsByFamily(mustParseIPs(derefSliceOfStringPtr(rv))) + logger.Info("Created the second LoadBalancer service", "svc-name", Service2Name, "v4-IPs", svc2IPv4s, "v6-IPs", svc2IPv6s) + Expect(svc2IPv4s).To(Equal(svc1IPv4s)) + Expect(svc2IPv6s).To(Equal(svc1IPv6s)) }) var validator *SecurityGroupValidator @@ -806,13 +808,13 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( if len(svc1IPv4s) > 0 { Expect( validator.HasExactAllowRule(expectedProtocol, []string{"Internet"}, svc1IPv4s, expectedDstPorts), - ).To(BeTrue(), "Should not have a rule for allowing IPv4 traffic from Internet") + ).To(BeTrue(), "Should have a rule for allowing IPv4 traffic from Internet") } if len(svc1IPv6s) > 0 { Expect( validator.HasExactAllowRule(expectedProtocol, []string{"Internet"}, svc1IPv6s, expectedDstPorts), - ).To(BeTrue(), "Should not have a rule for allowing IPv6 traffic from Internet") + ).To(BeTrue(), "Should have a rule for allowing IPv6 traffic from Internet") } }) @@ -824,7 +826,7 @@ var _ = Describe("Network security group", Label(utils.TestSuiteLabelNSG), func( By("Checking if the rule for allowing traffic from Internet exists") Expect( validator.HasExactAllowRule(expectedProtocol, []string{"Internet"}, svc2IPs, expectedDstPorts), - ).To(BeTrue(), "Should not have a rule for allowing traffic from Internet") + ).To(BeTrue(), "Should have a rule for allowing traffic from Internet") }) }) }) @@ -877,11 +879,20 @@ func (v *SecurityGroupValidator) HasDenyAllRuleForDestination(dstAddresses []net } func SecurityGroupNotHasRuleForDestination(nsg *aznetwork.SecurityGroup, dstAddresses []netip.Addr) bool { + logger := GinkgoLogr.WithName("SecurityGroupNotHasRuleForDestination"). + WithValues("nsg-name", nsg.Name). + WithValues("dst-addresses", dstAddresses) + if len(dstAddresses) == 0 { + logger.Info("skip") + return true + } + logger.Info("checking") dsts := sets.NewString() for _, ip := range dstAddresses { dsts.Insert(ip.String()) } for _, rule := range nsg.Properties.SecurityRules { + logger.Info("checking rule", "rule-name", rule.Name, "rule", rule) if rule.Properties.DestinationAddressPrefix != nil && dsts.Has(*rule.Properties.DestinationAddressPrefix) { return false }