diff --git a/core/pkg/ingress/annotations/ipwhitelist/main.go b/core/pkg/ingress/annotations/ipwhitelist/main.go index 0e44ff9d64..d42f0fad3e 100644 --- a/core/pkg/ingress/annotations/ipwhitelist/main.go +++ b/core/pkg/ingress/annotations/ipwhitelist/main.go @@ -56,8 +56,9 @@ func (a ipwhitelist) Parse(ing *extensions.Ingress) (interface{}, error) { sort.Strings(defBackend.WhitelistSourceRange) val, err := parser.GetStringAnnotation(whitelist, ing) - if err != nil { - return &SourceRange{CIDR: defBackend.WhitelistSourceRange}, err + // A missing annotation is not a problem, just use the default + if err == ing_errors.ErrMissingAnnotations { + return &SourceRange{CIDR: defBackend.WhitelistSourceRange}, nil } values := strings.Split(val, ",") diff --git a/core/pkg/ingress/annotations/ipwhitelist/main_test.go b/core/pkg/ingress/annotations/ipwhitelist/main_test.go index d3f60adc1a..190a36f896 100644 --- a/core/pkg/ingress/annotations/ipwhitelist/main_test.go +++ b/core/pkg/ingress/annotations/ipwhitelist/main_test.go @@ -25,6 +25,7 @@ import ( "k8s.io/kubernetes/pkg/util/intstr" "k8s.io/ingress/core/pkg/ingress/defaults" + "k8s.io/ingress/core/pkg/ingress/errors" ) func buildIngress() *extensions.Ingress { @@ -63,10 +64,11 @@ func buildIngress() *extensions.Ingress { } type mockBackend struct { + defaults.Backend } func (m mockBackend) GetDefaultBackend() defaults.Backend { - return defaults.Backend{} + return m.Backend } func TestParseAnnotations(t *testing.T) { @@ -105,16 +107,21 @@ func TestParseAnnotations(t *testing.T) { t.Errorf("expected error parsing an invalid cidr") } + if !errors.IsLocationDenied(err) { + t.Errorf("expected LocationDenied error: %+v", err) + } + delete(data, whitelist) - ing.SetAnnotations(data) i, err = p.Parse(ing) + + if err != nil { + t.Errorf("unexpected error when no annotation present: %v", err) + } + sr, ok = i.(*SourceRange) if !ok { t.Errorf("expected a SourceRange type") } - if err == nil { - t.Errorf("expected error parsing an invalid cidr") - } if !strsEquals(sr.CIDR, []string{}) { t.Errorf("expected empty CIDR but %v returned", sr.CIDR) } @@ -140,6 +147,85 @@ func TestParseAnnotations(t *testing.T) { } } +// Test that when we have a whitelist set on the Backend that is used when we +// don't have the annotation +func TestParseAnnotationsWithDefaultConfig(t *testing.T) { + // TODO: convert test cases to tables + ing := buildIngress() + + mockBackend := mockBackend{} + mockBackend.Backend.WhitelistSourceRange = []string{"4.4.4.0/24", "1.2.3.4/32"} + testNet := "10.0.0.0/24" + enet := []string{testNet} + + data := map[string]string{} + data[whitelist] = testNet + ing.SetAnnotations(data) + + expected := &SourceRange{ + CIDR: enet, + } + + p := NewParser(mockBackend) + + i, err := p.Parse(ing) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + sr, ok := i.(*SourceRange) + if !ok { + t.Errorf("expected a SourceRange type") + } + + if !reflect.DeepEqual(sr, expected) { + t.Errorf("expected %v but returned %s", sr, expected) + } + + data[whitelist] = "www" + _, err = p.Parse(ing) + if err == nil { + t.Errorf("expected error parsing an invalid cidr") + } + if !errors.IsLocationDenied(err) { + t.Errorf("expected LocationDenied error: %+v", err) + } + + delete(data, whitelist) + i, err = p.Parse(ing) + + if err != nil { + t.Errorf("unexpected error when no annotation present: %v", err) + } + + sr, ok = i.(*SourceRange) + if !ok { + t.Errorf("expected a SourceRange type") + } + if !strsEquals(sr.CIDR, mockBackend.WhitelistSourceRange) { + t.Errorf("expected fallback CIDR but %v returned", sr.CIDR) + } + + i, _ = p.Parse(&extensions.Ingress{}) + sr, ok = i.(*SourceRange) + if !ok { + t.Errorf("expected a SourceRange type") + } + if !strsEquals(sr.CIDR, mockBackend.WhitelistSourceRange) { + t.Errorf("expected fallback CIDR but %v returned", sr.CIDR) + } + + data[whitelist] = "2.2.2.2/32,1.1.1.1/32,3.3.3.0/24" + i, _ = p.Parse(ing) + sr, ok = i.(*SourceRange) + if !ok { + t.Errorf("expected a SourceRange type") + } + ecidr := []string{"1.1.1.1/32", "2.2.2.2/32", "3.3.3.0/24"} + if !strsEquals(sr.CIDR, ecidr) { + t.Errorf("Expected %v CIDR but %v returned", ecidr, sr.CIDR) + } +} + func strsEquals(a, b []string) bool { if len(a) != len(b) { return false