diff --git a/pkg/auth/authenticator.go b/pkg/auth/authenticator.go index 213bbdf..3f3c6af 100644 --- a/pkg/auth/authenticator.go +++ b/pkg/auth/authenticator.go @@ -91,9 +91,6 @@ const ( // CerberusReasonEmptySourceIp means that source ip is empty CerberusReasonEmptySourceIp CerberusReason = "source-ip-empty" - // CerberusReasonBadIpList means that there is no valid public ip for validation - CerberusReasonNoValidIp CerberusReason = "no-valid-ip" - // CerberusReasonBadIpList means that ip list items are not in valid patterns which is CIDR notation of the networks CerberusReasonBadIpList CerberusReason = "bad-ip-list" @@ -297,11 +294,7 @@ func (a *Authenticator) TestAccess(request *Request, wsvc ServicesCacheEntry) (b // Check x-forwarded-for header against IP allow list if len(ac.Spec.IpAllowList) > 0 { - publicIp := lastPublicIp(ipList) - if publicIp == "" { - return false, CerberusReasonNoValidIp, newExtraHeaders - } - ipAllowed, err := checkIP(publicIp, ac.Spec.IpAllowList) + ipAllowed, err := checkIP(ipList, ac.Spec.IpAllowList) if err != nil { return false, CerberusReasonBadIpList, newExtraHeaders } @@ -409,31 +402,21 @@ func NewAuthenticator(logger logr.Logger) (*Authenticator, error) { return &a, nil } -// lastPublicIp will identify the last valid public IP address within the list of IPs -// will return an error if it cannot find any valid public IP addresses in the input list. -func lastPublicIp(ips []string) string { - for i := len(ips) - 1; i >= 0; i-- { - clientIP := net.ParseIP(ips[i]) - if clientIP != nil && !clientIP.IsPrivate() { - return ips[i] - } - } - return "" -} - // checkIP checks if given ip is a member of given CIDR networks or not // ipAllowList should be CIDR notation of the networks or net.ParseError will be retuned -func checkIP(ip string, ipAllowList []string) (bool, error) { - clientIP := net.ParseIP(ip) - - for _, AllowedRangeIP := range ipAllowList { - _, subnet, err := net.ParseCIDR(AllowedRangeIP) - if err != nil { - return false, err - } +func checkIP(ips []string, ipAllowList []string) (bool, error) { + for _, ip := range ips { + clientIP := net.ParseIP(ip) + + for _, AllowedRangeIP := range ipAllowList { + _, subnet, err := net.ParseCIDR(AllowedRangeIP) + if err != nil { + return false, err + } - if subnet.Contains(clientIP) { - return true, nil + if subnet.Contains(clientIP) { + return true, nil + } } } return false, nil diff --git a/pkg/auth/authenticator_test.go b/pkg/auth/authenticator_test.go index f004544..600dbb0 100644 --- a/pkg/auth/authenticator_test.go +++ b/pkg/auth/authenticator_test.go @@ -35,7 +35,7 @@ func BenchmarkCheckIPWithLargeInput(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - _, _ = checkIP(testIP, ipAllowList) + _, _ = checkIP([]string{testIP}, ipAllowList) } }