Skip to content

Commit

Permalink
fix(aws/inboundIp): revoke and authorize the whole group
Browse files Browse the repository at this point in the history
Bug: given only 1 IP in the SG, 1 is revoked, 1 is authorized

Given one IP in the SG '92.143.36.217'
Run `gf aws inboundIp`
Result: the SG now contains one IP '2a01:cb14:1686::'
Expected result:
  the SG contains both IPs '2a01:cb14:1686::', and '92.143.36.217'

Why ?
  - IPv4 is skipped, because it belongs to the SG
  - IPv6 is added to the group
  - all rules are revoked
  - the group is authorized

Fix #1:
  Revoke and authorize the whole group.
  But now the group is revoked+authorized even to the identical

Fix #2:
  Add comparison methods for []IpRange []Ipv6Range []IpPermission
  Add tests
  • Loading branch information
gforien-externe committed Oct 25, 2024
1 parent 99f4cc7 commit c9586fd
Show file tree
Hide file tree
Showing 4 changed files with 315 additions and 12 deletions.
68 changes: 68 additions & 0 deletions internal/aws/ip_permissions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package aws

import (
"log"
"strings"

"github.com/aws/aws-sdk-go-v2/service/ec2/types"
)

func EqualsIpPerms(a []types.IpPermission, b []types.IpPermission) bool {
if a == nil || b == nil {
return a == nil && b == nil
}
if len(a) != len(b) {
return false
}
for i := range a {
if !EqualsString(a[i].IpProtocol, b[i].IpProtocol) {
return false
}
if !EqualsIpRange(a[i].IpRanges, b[i].IpRanges) {
log.Default().Print("IpRanges not equal")
return false
}
if !EqualsIpv6Range(a[i].Ipv6Ranges, b[i].Ipv6Ranges) {
log.Default().Print("Ipv6Ranges not equal")
return false
}
}
return true
}

func EqualsIpRange(a []types.IpRange, b []types.IpRange) bool {
if a == nil || b == nil {
return a == nil && b == nil
}
if len(a) != len(b) {
return false
}
for i := range a {
if !EqualsString(a[i].CidrIp, b[i].CidrIp) {
return false
}
}
return true
}

func EqualsIpv6Range(a []types.Ipv6Range, b []types.Ipv6Range) bool {
if a == nil || b == nil {
return a == nil && b == nil
}
if len(a) != len(b) {
return false
}
for i := range a {
if !EqualsString(a[i].CidrIpv6, b[i].CidrIpv6) {
return false
}
}
return true
}

func EqualsString(a *string, b *string) bool {
if a == nil || b == nil {
return a == nil && b == nil
}
return strings.EqualFold(*a, *b)
}
212 changes: 212 additions & 0 deletions internal/aws/ip_permissions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
package aws

import (
"testing"

"github.com/aws/aws-sdk-go-v2/service/ec2/types"
)

func TestEqualsIpPerms(t *testing.T) {
tests := []struct {
name string
a []types.IpPermission
b []types.IpPermission
expected bool
}{
{
name: "Nil slices",
a: nil,
b: nil,
expected: true,
},
{
name: "One nil slice",
a: nil,
b: []types.IpPermission{},
expected: false,
},
{
name: "Different lengths",
a: []types.IpPermission{{IpProtocol: stringPointer("tcp")}},
b: []types.IpPermission{},
expected: false,
},
{
name: "Equal slices",
a: []types.IpPermission{{IpProtocol: stringPointer("tcp")}},
b: []types.IpPermission{{IpProtocol: stringPointer("tcp")}},
expected: true,
},
{
name: "Different protocols",
a: []types.IpPermission{{IpProtocol: stringPointer("tcp")}},
b: []types.IpPermission{{IpProtocol: stringPointer("udp")}},
expected: false,
},
{
name: "Different IP ranges",
a: []types.IpPermission{
{IpProtocol: stringPointer("tcp"), IpRanges: []types.IpRange{{CidrIp: stringPointer("192.168.1.0/24")}}},
},
b: []types.IpPermission{
{IpProtocol: stringPointer("tcp"), IpRanges: []types.IpRange{{CidrIp: stringPointer("10.0.0.0/24")}}},
},
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := EqualsIpPerms(tt.a, tt.b)
if got != tt.expected {
t.Errorf("EqualsIpPerms() = %v, want %v", got, tt.expected)
}
})
}
}

func TestEqualsIpRange(t *testing.T) {
tests := []struct {
name string
a []types.IpRange
b []types.IpRange
expected bool
}{
{
name: "Nil slices",
a: nil,
b: nil,
expected: true,
},
{
name: "One nil slice",
a: nil,
b: []types.IpRange{},
expected: false,
},
{
name: "Different lengths",
a: []types.IpRange{{CidrIp: stringPointer("192.168.1.0/24")}},
b: []types.IpRange{},
expected: false,
},
{
name: "Equal slices",
a: []types.IpRange{{CidrIp: stringPointer("192.168.1.0/24")}},
b: []types.IpRange{{CidrIp: stringPointer("192.168.1.0/24")}},
expected: true,
},
{
name: "Different CIDR IPs",
a: []types.IpRange{{CidrIp: stringPointer("192.168.1.0/24")}},
b: []types.IpRange{{CidrIp: stringPointer("10.0.0.0/24")}},
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := EqualsIpRange(tt.a, tt.b)
if got != tt.expected {
t.Errorf("EqualsIpRange() = %v, want %v", got, tt.expected)
}
})
}
}

func TestEqualsIpv6Range(t *testing.T) {
tests := []struct {
name string
a []types.Ipv6Range
b []types.Ipv6Range
expected bool
}{
{
name: "Nil slices",
a: nil,
b: nil,
expected: true,
},
{
name: "One nil slice",
a: nil,
b: []types.Ipv6Range{},
expected: false,
},
{
name: "Different lengths",
a: []types.Ipv6Range{{CidrIpv6: stringPointer("2001:db8::/32")}},
b: []types.Ipv6Range{},
expected: false,
},
{
name: "Equal slices",
a: []types.Ipv6Range{{CidrIpv6: stringPointer("2001:db8::/32")}},
b: []types.Ipv6Range{{CidrIpv6: stringPointer("2001:db8::/32")}},
expected: true,
},
{
name: "Different CIDR IPv6s",
a: []types.Ipv6Range{{CidrIpv6: stringPointer("2001:db8::/32")}},
b: []types.Ipv6Range{{CidrIpv6: stringPointer("2001:0db8::/32")}},
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := EqualsIpv6Range(tt.a, tt.b)
if got != tt.expected {
t.Errorf("EqualsIpv6Range() = %v, want %v", got, tt.expected)
}
})
}
}

func TestEqualsString(t *testing.T) {
tests := []struct {
name string
a *string
b *string
expected bool
}{
{
name: "Both nil",
a: nil,
b: nil,
expected: true,
},
{
name: "One nil",
a: nil,
b: stringPointer("test"),
expected: false,
},
{
name: "Different strings",
a: stringPointer("abc"),
b: stringPointer("ABC"),
expected: true,
},
{
name: "Same strings",
a: stringPointer("abc"),
b: stringPointer("def"),
expected: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := EqualsString(tt.a, tt.b)
if got != tt.expected {
t.Errorf("EqualsString() = %v, want %v", got, tt.expected)
}
})
}
}

// Helper function to create string pointers for testing
func stringPointer(s string) *string {
return &s
}
16 changes: 4 additions & 12 deletions internal/aws/security_group.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,20 +42,12 @@ func FindAndUpdateSg(cfg aws.Config, ips []net.Ip) {
wg.Wait()
}

func AuthorizeInboundIps(ec2Client *ec2.Client, sg types.SecurityGroup, ips []net.Ip) {
func AuthorizeInboundIps(ec2Client *ec2.Client, sg types.SecurityGroup, ips net.IpList) {
log.Default().Printf("Checking security group '%s'", *sg.GroupId)
perms := []types.IpPermission{}
for _, ip := range ips {

if ip.ExistsInAwsSg(sg) {
log.Default().Printf("Security group '%s' allows '%s'. Skipping.\n", *sg.GroupId, ip)
continue
}

perms = append(perms, ip.ToAwsIpPerms())
log.Default().Printf("Adding %v to group", ip)
}
if len(perms) == 0 {
perms := ips.ToAwsIpPerms()
if EqualsIpPerms(perms, sg.IpPermissions) {
log.Default().Printf("Security group '%s' allow group. Skipping.", *sg.GroupId)
return
}

Expand Down
31 changes: 31 additions & 0 deletions internal/net/ip_list.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package net

import (
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
)

type IpList []Ip

func (ipl IpList) ToAwsIpPerms() []types.IpPermission {
var ipv4 []types.IpRange
var ipv6 []types.Ipv6Range

for _, ip := range ipl {
cidr := ip.GetCidr()
switch ip.GetVersion() {
case IPv6:
ipv6 = append(ipv6, types.Ipv6Range{CidrIpv6: aws.String(cidr)})
default:
ipv4 = append(ipv4, types.IpRange{CidrIp: aws.String(cidr)})
}
}

return []types.IpPermission{
{
IpProtocol: aws.String("-1"),
IpRanges: ipv4,
Ipv6Ranges: ipv6,
},
}
}

0 comments on commit c9586fd

Please sign in to comment.