diff --git a/controllers/gce/controller/fakes.go b/controllers/gce/controller/fakes.go index dc8a80f586..11866b3685 100644 --- a/controllers/gce/controller/fakes.go +++ b/controllers/gce/controller/fakes.go @@ -67,7 +67,7 @@ func NewFakeClusterManager(clusterName, firewallName string) *fakeClusterManager testDefaultBeNodePort, namer, ) - frPool := firewalls.NewFirewallPool(firewalls.NewFakeFirewallRules(), namer) + frPool := firewalls.NewFirewallPool(firewalls.NewFakeFirewallsProvider(namer), namer) cm := &ClusterManager{ ClusterNamer: namer, instancePool: nodePool, diff --git a/controllers/gce/firewalls/fakes.go b/controllers/gce/firewalls/fakes.go index 73f1a56f1b..dc0723e0ca 100644 --- a/controllers/gce/firewalls/fakes.go +++ b/controllers/gce/firewalls/fakes.go @@ -18,6 +18,7 @@ package firewalls import ( "fmt" + "strconv" compute "google.golang.org/api/compute/v1" netset "k8s.io/kubernetes/pkg/util/net/sets" @@ -25,81 +26,76 @@ import ( "k8s.io/ingress/controllers/gce/utils" ) -type fakeFirewallRules struct { - fw []*compute.Firewall - namer utils.Namer +type fakeFirewallsProvider struct { + fw map[string]*compute.Firewall + namer *utils.Namer } -func (f *fakeFirewallRules) GetFirewall(name string) (*compute.Firewall, error) { - for _, rule := range f.fw { - if rule.Name == name { - return rule, nil - } +// NewFakeFirewallsProvider creates a fake for firewall rules. +func NewFakeFirewallsProvider(namer *utils.Namer) *fakeFirewallsProvider { + return &fakeFirewallsProvider{ + fw: make(map[string]*compute.Firewall), + namer: namer, } - return nil, fmt.Errorf("firewall rule %v not found", name) } -func (f *fakeFirewallRules) CreateFirewall(name, msgTag string, srcRange netset.IPNet, ports []int64, hosts []string) error { +func (f *fakeFirewallsProvider) GetFirewall(prefixedName string) (*compute.Firewall, error) { + rule, exists := f.fw[prefixedName] + if exists { + return rule, nil + } + return nil, utils.FakeGoogleAPINotFoundErr() +} + +func (f *fakeFirewallsProvider) CreateFirewall(name, msgTag string, srcRange netset.IPNet, ports []int64, hosts []string) error { + prefixedName := f.namer.FrName(name) strPorts := []string{} for _, p := range ports { - strPorts = append(strPorts, fmt.Sprintf("%v", p)) + strPorts = append(strPorts, strconv.FormatInt(p, 10)) + } + if _, exists := f.fw[prefixedName]; exists { + return fmt.Errorf("firewall rule %v already exists", prefixedName) } - f.fw = append(f.fw, &compute.Firewall{ + + f.fw[prefixedName] = &compute.Firewall{ // To accurately mimic the cloudprovider we need to add the k8s-fw // prefix to the given rule name. - Name: f.namer.FrName(name), + Name: prefixedName, SourceRanges: srcRange.StringSlice(), Allowed: []*compute.FirewallAllowed{{Ports: strPorts}}, - }) + } return nil } -func (f *fakeFirewallRules) DeleteFirewall(name string) error { - firewalls := []*compute.Firewall{} - exists := false +func (f *fakeFirewallsProvider) DeleteFirewall(name string) error { // We need the full name for the same reason as CreateFirewall. - name = f.namer.FrName(name) - for _, rule := range f.fw { - if rule.Name == name { - exists = true - continue - } - firewalls = append(firewalls, rule) - } + prefixedName := f.namer.FrName(name) + _, exists := f.fw[prefixedName] if !exists { - return fmt.Errorf("failed to find health check %v", name) + return utils.FakeGoogleAPINotFoundErr() } - f.fw = firewalls + + delete(f.fw, prefixedName) return nil } -func (f *fakeFirewallRules) UpdateFirewall(name, msgTag string, srcRange netset.IPNet, ports []int64, hosts []string) error { - var exists bool +func (f *fakeFirewallsProvider) UpdateFirewall(name, msgTag string, srcRange netset.IPNet, ports []int64, hosts []string) error { strPorts := []string{} for _, p := range ports { - strPorts = append(strPorts, fmt.Sprintf("%v", p)) + strPorts = append(strPorts, strconv.FormatInt(p, 10)) } - // To accurately mimic the cloudprovider we need to add the k8s-fw - // prefix to the given rule name. - name = f.namer.FrName(name) - for i := range f.fw { - if f.fw[i].Name == name { - exists = true - f.fw[i] = &compute.Firewall{ - Name: name, - SourceRanges: srcRange.StringSlice(), - Allowed: []*compute.FirewallAllowed{{Ports: strPorts}}, - } - } - } - if exists { - return nil + // We need the full name for the same reason as CreateFirewall. + prefixedName := f.namer.FrName(name) + _, exists := f.fw[prefixedName] + if !exists { + return fmt.Errorf("update failed for rule %v, srcRange %v ports %v, rule not found", prefixedName, srcRange, ports) } - return fmt.Errorf("update failed for rule %v, srcRange %v ports %v, rule not found", name, srcRange, ports) -} -// NewFakeFirewallRules creates a fake for firewall rules. -func NewFakeFirewallRules() *fakeFirewallRules { - return &fakeFirewallRules{fw: []*compute.Firewall{}, namer: utils.Namer{}} + f.fw[prefixedName] = &compute.Firewall{ + Name: name, + SourceRanges: srcRange.StringSlice(), + Allowed: []*compute.FirewallAllowed{{Ports: strPorts}}, + } + return nil } diff --git a/controllers/gce/firewalls/firewalls.go b/controllers/gce/firewalls/firewalls.go index 1d40b9aa06..0e01529dca 100644 --- a/controllers/gce/firewalls/firewalls.go +++ b/controllers/gce/firewalls/firewalls.go @@ -75,9 +75,14 @@ func (fr *FirewallRules) Sync(nodePorts []int64, nodeNames []string) error { existingPorts.Insert(p) } } - if requiredPorts.Equal(existingPorts) { + + requiredCIDRs := sets.NewString(l7SrcRanges...) + existingCIDRs := sets.NewString(rule.SourceRanges...) + + if requiredPorts.Equal(existingPorts) && requiredCIDRs.Equal(existingCIDRs) { return nil } + glog.V(3).Infof("Firewall rule %v already exists, updating nodeports %v", name, nodePorts) return fr.cloud.UpdateFirewall(suffix, "GCE L7 firewall rule", fr.srcRanges, nodePorts, nodeNames) } @@ -85,7 +90,12 @@ func (fr *FirewallRules) Sync(nodePorts []int64, nodeNames []string) error { // Shutdown shuts down this firewall rules manager. func (fr *FirewallRules) Shutdown() error { glog.Infof("Deleting firewall rule with suffix %v", fr.namer.FrSuffix()) - return fr.cloud.DeleteFirewall(fr.namer.FrSuffix()) + err := fr.cloud.DeleteFirewall(fr.namer.FrSuffix()) + if err != nil && utils.IsHTTPErrorCode(err, 404) { + glog.Infof("Firewall with suffix %v didn't exist at Shutdown", fr.namer.FrSuffix()) + return nil + } + return err } // GetFirewall just returns the firewall object corresponding to the given name. diff --git a/controllers/gce/firewalls/firewalls_test.go b/controllers/gce/firewalls/firewalls_test.go new file mode 100644 index 0000000000..5abd9b36ec --- /dev/null +++ b/controllers/gce/firewalls/firewalls_test.go @@ -0,0 +1,95 @@ +package firewalls + +import ( + "strconv" + "testing" + + "k8s.io/apimachinery/pkg/util/sets" + "k8s.io/ingress/controllers/gce/utils" + netset "k8s.io/kubernetes/pkg/util/net/sets" +) + +func TestSyncFirewallPool(t *testing.T) { + namer := utils.NewNamer("ABC", "XYZ") + fwp := NewFakeFirewallsProvider(namer) + fp := NewFirewallPool(fwp, namer) + ruleName := namer.FrName(namer.FrSuffix()) + + // Test creating a firewall rule via Sync + nodePorts := []int64{80, 443, 3000} + nodes := []string{"node-a", "node-b", "node-c"} + err := fp.Sync(nodePorts, nodes) + if err != nil { + t.Errorf("unexpected err when syncing firewall, err: %v", err) + } + verifyFirewallRule(fwp, ruleName, nodePorts, nodes, l7SrcRanges, t) + + // Sync to fewer ports + nodePorts = []int64{80, 443} + err = fp.Sync(nodePorts, nodes) + if err != nil { + t.Errorf("unexpected err when syncing firewall, err: %v", err) + } + verifyFirewallRule(fwp, ruleName, nodePorts, nodes, l7SrcRanges, t) + + all := "0.0.0.0/0" + srcRanges, _ := netset.ParseIPNets(all) + err = fwp.UpdateFirewall(namer.FrSuffix(), "", srcRanges, nodePorts, nodes) + if err != nil { + t.Errorf("failed to update firewall rule, err: %v", err) + } + verifyFirewallRule(fwp, ruleName, nodePorts, nodes, []string{all}, t) + + // Run Sync and expect l7 src ranges to be returned + err = fp.Sync(nodePorts, nodes) + if err != nil { + t.Errorf("unexpected err when syncing firewall, err: %v", err) + } + verifyFirewallRule(fwp, ruleName, nodePorts, nodes, l7SrcRanges, t) + + // Add node and expect firwall to change nodes list + nodes = []string{"node-a", "node-b", "node-c", "node-d"} + err = fp.Sync(nodePorts, nodes) + if err != nil { + t.Errorf("unexpected err when syncing firewall, err: %v", err) + } + verifyFirewallRule(fwp, ruleName, nodePorts, nodes, l7SrcRanges, t) + + // Remove all ports and expect firewall rule to disappear + nodePorts = []int64{} + err = fp.Sync(nodePorts, nodes) + if err != nil { + t.Errorf("unexpected err when syncing firewall, err: %v", err) + } + + err = fp.Shutdown() + if err != nil { + t.Errorf("unexpected err when deleting firewall, err: %v", err) + } +} + +func verifyFirewallRule(fwp *fakeFirewallsProvider, ruleName string, expectedPorts []int64, expectedNodes, expectedCIDRs []string, t *testing.T) { + var strPorts []string + for _, v := range expectedPorts { + strPorts = append(strPorts, strconv.FormatInt(v, 10)) + } + + // Verify firewall rule was created + f, err := fwp.GetFirewall(ruleName) + if err != nil { + t.Errorf("could not retrieve firewall via cloud api, err %v", err) + } + + // Verify firwall rule has correct ports + if !sets.NewString(f.Allowed[0].Ports...).Equal(sets.NewString(strPorts...)) { + t.Errorf("allowed ports doesn't equal expected ports, Actual: %v, Expected: %v", f.Allowed[0].Ports, strPorts) + } + + // Verify firwall rule has correct CIDRs + if !sets.NewString(f.SourceRanges...).Equal(sets.NewString(expectedCIDRs...)) { + t.Errorf("source CIDRs doesn't equal expected CIDRs. Actual: %v, Expected: %v", f.SourceRanges, expectedCIDRs) + } + + // Verify firwall rule has correct nodes + // TODO: Check host tags are updated +} diff --git a/controllers/gce/utils/utils.go b/controllers/gce/utils/utils.go index 288258d9c9..f59d7b259d 100644 --- a/controllers/gce/utils/utils.go +++ b/controllers/gce/utils/utils.go @@ -312,6 +312,11 @@ func (g GCEURLMap) PutDefaultBackend(d *compute.BackendService) { } } +// FakeNotFoundErr creates a NotFound error with type googleapi.Error +func FakeGoogleAPINotFoundErr() *googleapi.Error { + return &googleapi.Error{Code: 404} +} + // IsHTTPErrorCode checks if the given error matches the given HTTP Error code. // For this to work the error must be a googleapi Error. func IsHTTPErrorCode(err error, code int) bool {