diff --git a/go.mod b/go.mod index f3e3645e..d3e9c5a8 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/campoy/embedmd v1.0.0 github.com/containernetworking/cni v1.0.1 github.com/containernetworking/plugins v1.1.1 - github.com/coreos/go-iptables v0.6.0 + github.com/coreos/go-iptables v0.6.1-0.20220901214115-d2b8608923d1 github.com/go-kit/kit v0.9.0 github.com/kylelemons/godebug v0.0.0-20170820004349-d65d576e9348 github.com/metalmatze/signal v0.0.0-20210307161603-1c9aa721a97a diff --git a/go.sum b/go.sum index 69f8c878..d3bf95cb 100644 --- a/go.sum +++ b/go.sum @@ -106,8 +106,8 @@ github.com/containernetworking/plugins v1.1.1 h1:+AGfFigZ5TiQH00vhR8qPeSatj53eNG github.com/containernetworking/plugins v1.1.1/go.mod h1:Sr5TH/eBsGLXK/h71HeLfX19sZPp3ry5uHSkI4LPxV8= github.com/coreos/bbolt v1.3.2/go.mod h1:iRUV2dpdMOn7Bo10OQBFzIJO9kkE559Wcmn+qkEiiKk= github.com/coreos/etcd v3.3.13+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= -github.com/coreos/go-iptables v0.6.0 h1:is9qnZMPYjLd8LYqmm/qlE+wwEgJIkTYdhV3rfZo4jk= -github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= +github.com/coreos/go-iptables v0.6.1-0.20220901214115-d2b8608923d1 h1:zSiUKnogKeEwIIeUQP/WPH7m0BJ/IvW0VyL4muaauUY= +github.com/coreos/go-iptables v0.6.1-0.20220901214115-d2b8608923d1/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= github.com/coreos/go-oidc v2.1.0+incompatible/go.mod h1:CgnwVTmzoESiwO9qyAFEMiHoZ1nMCKZlZ9V6mm3/LKc= github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= diff --git a/pkg/encapsulation/cilium.go b/pkg/encapsulation/cilium.go index 9e342923..bfbb327a 100644 --- a/pkg/encapsulation/cilium.go +++ b/pkg/encapsulation/cilium.go @@ -96,8 +96,8 @@ func (f *cilium) Init(_ int) error { } // Rules is a no-op. -func (f *cilium) Rules(_ []*net.IPNet) []iptables.Rule { - return nil +func (f *cilium) Rules(_ []*net.IPNet) iptables.RuleSet { + return iptables.RuleSet{} } // Set is a no-op. diff --git a/pkg/encapsulation/encapsulation.go b/pkg/encapsulation/encapsulation.go index 21e698a0..77fab57e 100644 --- a/pkg/encapsulation/encapsulation.go +++ b/pkg/encapsulation/encapsulation.go @@ -49,7 +49,7 @@ type Encapsulator interface { Gw(net.IP, net.IP, *net.IPNet) net.IP Index() int Init(int) error - Rules([]*net.IPNet) []iptables.Rule + Rules([]*net.IPNet) iptables.RuleSet Set(*net.IPNet) error Strategy() Strategy } diff --git a/pkg/encapsulation/flannel.go b/pkg/encapsulation/flannel.go index e08af61a..9375b8f5 100644 --- a/pkg/encapsulation/flannel.go +++ b/pkg/encapsulation/flannel.go @@ -95,8 +95,8 @@ func (f *flannel) Init(_ int) error { } // Rules is a no-op. -func (f *flannel) Rules(_ []*net.IPNet) []iptables.Rule { - return nil +func (f *flannel) Rules(_ []*net.IPNet) iptables.RuleSet { + return iptables.RuleSet{} } // Set is a no-op. diff --git a/pkg/encapsulation/ipip.go b/pkg/encapsulation/ipip.go index d92b39fc..54527835 100644 --- a/pkg/encapsulation/ipip.go +++ b/pkg/encapsulation/ipip.go @@ -65,20 +65,20 @@ func (i *ipip) Init(base int) error { // Rules returns a set of iptables rules that are necessary // when traffic between nodes must be encapsulated. -func (i *ipip) Rules(nodes []*net.IPNet) []iptables.Rule { - var rules []iptables.Rule +func (i *ipip) Rules(nodes []*net.IPNet) iptables.RuleSet { + rules := iptables.RuleSet{} proto := ipipProtocolName() - rules = append(rules, iptables.NewIPv4Chain("filter", "KILO-IPIP")) - rules = append(rules, iptables.NewIPv6Chain("filter", "KILO-IPIP")) - rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP")) - rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP")) + rules.AddToAppend(iptables.NewIPv4Chain("filter", "KILO-IPIP")) + rules.AddToAppend(iptables.NewIPv6Chain("filter", "KILO-IPIP")) + rules.AddToAppend(iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP")) + rules.AddToAppend(iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: jump to IPIP chain", "-j", "KILO-IPIP")) for _, n := range nodes { // Accept encapsulated traffic from peers. - rules = append(rules, iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(n.IP), "filter", "KILO-IPIP", "-s", n.String(), "-m", "comment", "--comment", "Kilo: allow IPIP traffic", "-j", "ACCEPT")) } // Drop all other IPIP traffic. - rules = append(rules, iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP")) - rules = append(rules, iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP")) + rules.AddToAppend(iptables.NewIPv4Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP")) + rules.AddToAppend(iptables.NewIPv6Rule("filter", "INPUT", "-p", proto, "-m", "comment", "--comment", "Kilo: reject other IPIP traffic", "-j", "DROP")) return rules } diff --git a/pkg/encapsulation/noop.go b/pkg/encapsulation/noop.go index d5b39064..ad9818dd 100644 --- a/pkg/encapsulation/noop.go +++ b/pkg/encapsulation/noop.go @@ -44,8 +44,8 @@ func (n Noop) Init(_ int) error { } // Rules will also do nothing. -func (n Noop) Rules(_ []*net.IPNet) []iptables.Rule { - return nil +func (n Noop) Rules(_ []*net.IPNet) iptables.RuleSet { + return iptables.RuleSet{} } // Set will also do nothing. diff --git a/pkg/iptables/fake.go b/pkg/iptables/fake.go index 24c97dd6..985b7b33 100644 --- a/pkg/iptables/fake.go +++ b/pkg/iptables/fake.go @@ -46,6 +46,24 @@ type fakeClient struct { var _ Client = &fakeClient{} +func (f *fakeClient) InsertUnique(table, chain string, pos int, spec ...string) error { + atomic.AddUint64(&f.calls, 1) + exists, err := f.Exists(table, chain, spec...) + if err != nil { + return err + } + if exists { + return nil + } + index := pos - 1 // iptables are 1-based + rule := &rule{table: table, chain: chain, spec: spec} + prefix := append([]Rule{}, f.storage[:index]...) + suffix := append([]Rule{}, f.storage[index:]...) + prefix = append(prefix, rule) + f.storage = append(prefix, suffix...) + return nil +} + func (f *fakeClient) AppendUnique(table, chain string, spec ...string) error { atomic.AddUint64(&f.calls, 1) exists, err := f.Exists(table, chain, spec...) diff --git a/pkg/iptables/iptables.go b/pkg/iptables/iptables.go index 4cad47cb..c63c88cb 100644 --- a/pkg/iptables/iptables.go +++ b/pkg/iptables/iptables.go @@ -46,6 +46,11 @@ func ipv6Disabled() (bool, error) { // Protocol represents an IP protocol. type Protocol byte +type RuleSet struct { + appendRules []Rule // Rules to append to the chain - order matters. + prependRules []Rule // Rules to prepend to the chain - order does not matter. +} + const ( // ProtocolIPv4 represents the IPv4 protocol. ProtocolIPv4 Protocol = iota @@ -53,6 +58,21 @@ const ( ProtocolIPv6 ) +func (rs *RuleSet) AddToAppend(rule Rule) { + rs.appendRules = append(rs.appendRules, rule) +} + +func (rs *RuleSet) AddToPrepend(rule Rule) { + rs.prependRules = append(rs.prependRules, rule) +} + +func (rs *RuleSet) AppendRuleSet(other RuleSet) RuleSet { + return RuleSet{ + appendRules: append(rs.appendRules, other.appendRules...), + prependRules: append(rs.prependRules, other.prependRules...), + } +} + // GetProtocol will return a protocol from the length of an IP address. func GetProtocol(ip net.IP) Protocol { if len(ip) == net.IPv4len || ip.To4() != nil { @@ -64,6 +84,7 @@ func GetProtocol(ip net.IP) Protocol { // Client represents any type that can administer iptables rules. type Client interface { AppendUnique(table string, chain string, rule ...string) error + InsertUnique(table, chain string, pos int, rule ...string) error Delete(table string, chain string, rule ...string) error Exists(table string, chain string, rule ...string) (bool, error) List(table string, chain string) ([]string, error) @@ -75,7 +96,8 @@ type Client interface { // Rule is an interface for interacting with iptables objects. type Rule interface { - Add(Client) error + Append(Client) error + Prepend(Client) error Delete(Client) error Exists(Client) (bool, error) String() string @@ -106,7 +128,14 @@ func NewIPv6Rule(table, chain string, spec ...string) Rule { return &rule{table, chain, spec, ProtocolIPv6} } -func (r *rule) Add(client Client) error { +func (r *rule) Prepend(client Client) error { + if err := client.InsertUnique(r.table, r.chain, 1, r.spec...); err != nil { + return fmt.Errorf("failed to add iptables rule: %v", err) + } + return nil +} + +func (r *rule) Append(client Client) error { if err := client.AppendUnique(r.table, r.chain, r.spec...); err != nil { return fmt.Errorf("failed to add iptables rule: %v", err) } @@ -162,7 +191,11 @@ func NewIPv6Chain(table, name string) Rule { return &chain{table, name, ProtocolIPv6} } -func (c *chain) Add(client Client) error { +func (c *chain) Prepend(client Client) error { + return c.Append(client) +} + +func (c *chain) Append(client Client) error { // Note: `ClearChain` creates a chain if it does not exist. if err := client.ClearChain(c.table, c.chain); err != nil { return fmt.Errorf("failed to add iptables chain: %v", err) @@ -224,8 +257,9 @@ type Controller struct { registerer prometheus.Registerer sync.Mutex - rules []Rule - subscribed bool + appendRules []Rule + prependRules []Rule + subscribed bool } // ControllerOption modifies the controller's configuration. @@ -333,14 +367,21 @@ func (c *Controller) reconcile() error { c.Lock() defer c.Unlock() var rc ruleCache - for i, r := range c.rules { + if err := c.reconcileAppendRules(rc); err != nil { + return err + } + return c.reconcilePrependRules(rc) +} + +func (c *Controller) reconcileAppendRules(rc ruleCache) error { + for i, r := range c.appendRules { ok, err := rc.exists(c.client(r.Proto()), r) if err != nil { return fmt.Errorf("failed to check if rule exists: %v", err) } if !ok { - level.Info(c.logger).Log("msg", fmt.Sprintf("applying %d iptables rules", len(c.rules)-i)) - if err := c.resetFromIndex(i, c.rules); err != nil { + level.Info(c.logger).Log("msg", fmt.Sprintf("applying %d iptables rules", len(c.appendRules)-i)) + if err := c.resetFromIndex(i, c.appendRules); err != nil { return fmt.Errorf("failed to add rule: %v", err) } break @@ -349,6 +390,22 @@ func (c *Controller) reconcile() error { return nil } +func (c *Controller) reconcilePrependRules(rc ruleCache) error { + for _, r := range c.prependRules { + ok, err := rc.exists(c.client(r.Proto()), r) + if err != nil { + return fmt.Errorf("failed to check if rule exists: %v", err) + } + if !ok { + level.Info(c.logger).Log("msg", "prepending iptables rule") + if err := r.Prepend(c.client(r.Proto())); err != nil { + return fmt.Errorf("failed to prepend rule: %v", err) + } + } + } + return nil +} + // resetFromIndex re-adds all rules starting from the given index. func (c *Controller) resetFromIndex(i int, rules []Rule) error { if i >= len(rules) { @@ -358,7 +415,7 @@ func (c *Controller) resetFromIndex(i int, rules []Rule) error { if err := rules[j].Delete(c.client(rules[j].Proto())); err != nil { return fmt.Errorf("failed to delete rule: %v", err) } - if err := rules[j].Add(c.client(rules[j].Proto())); err != nil { + if err := rules[j].Append(c.client(rules[j].Proto())); err != nil { return fmt.Errorf("failed to add rule: %v", err) } } @@ -383,34 +440,87 @@ func (c *Controller) deleteFromIndex(i int, rules *[]Rule) error { // Set idempotently overwrites any iptables rules previously defined // for the controller with the given set of rules. -func (c *Controller) Set(rules []Rule) error { +func (c *Controller) Set(rules RuleSet) error { c.Lock() defer c.Unlock() + if err := c.setAppendRules(rules.appendRules); err != nil { + return err + } + return c.setPrependRules(rules.prependRules) +} + +func (c *Controller) setAppendRules(appendRules []Rule) error { var i int - for ; i < len(rules); i++ { - if i < len(c.rules) { - if rules[i].String() != c.rules[i].String() { - if err := c.deleteFromIndex(i, &c.rules); err != nil { + for ; i < len(appendRules); i++ { + if i < len(c.appendRules) { + if appendRules[i].String() != c.appendRules[i].String() { + if err := c.deleteFromIndex(i, &c.appendRules); err != nil { return err } } } - if i >= len(c.rules) { - if err := rules[i].Add(c.client(rules[i].Proto())); err != nil { + if i >= len(c.appendRules) { + if err := appendRules[i].Append(c.client(appendRules[i].Proto())); err != nil { return fmt.Errorf("failed to add rule: %v", err) } - c.rules = append(c.rules, rules[i]) + c.appendRules = append(c.appendRules, appendRules[i]) } + } + err := c.deleteFromIndex(i, &c.appendRules) + if err != nil { + return fmt.Errorf("failed to delete rule: %v", err) + } + return nil +} +func (c *Controller) setPrependRules(prependRules []Rule) error { + for _, prependRule := range prependRules { + if !containsRule(c.prependRules, prependRule) { + if err := prependRule.Prepend(c.client(prependRule.Proto())); err != nil { + return fmt.Errorf("failed to add rule: %v", err) + } + c.prependRules = append(c.prependRules, prependRule) + } } - return c.deleteFromIndex(i, &c.rules) + for _, existingRule := range c.prependRules { + if !containsRule(prependRules, existingRule) { + if err := existingRule.Delete(c.client(existingRule.Proto())); err != nil { + return fmt.Errorf("failed to delete rule: %v", err) + } + c.prependRules = removeRule(c.prependRules, existingRule) + } + } + return nil +} + +func removeRule(rules []Rule, toRemove Rule) []Rule { + ret := make([]Rule, 0, len(rules)) + for _, rule := range rules { + if rule.String() != toRemove.String() { + ret = append(ret, rule) + } + } + return ret +} + +func containsRule(haystack []Rule, needle Rule) bool { + for _, element := range haystack { + if element.String() == needle.String() { + return true + } + } + return false } // CleanUp will clean up any rules created by the controller. func (c *Controller) CleanUp() error { c.Lock() defer c.Unlock() - return c.deleteFromIndex(0, &c.rules) + err := c.deleteFromIndex(0, &c.prependRules) + if err != nil { + return err + } + return c.deleteFromIndex(0, &c.appendRules) } func (c *Controller) client(p Protocol) Client { diff --git a/pkg/iptables/iptables_test.go b/pkg/iptables/iptables_test.go index 25c5c176..447eb826 100644 --- a/pkg/iptables/iptables_test.go +++ b/pkg/iptables/iptables_test.go @@ -18,70 +18,94 @@ import ( "testing" ) -var rules = []Rule{ +var appendRules = []Rule{ NewIPv4Rule("filter", "FORWARD", "-s", "10.4.0.0/16", "-j", "ACCEPT"), NewIPv4Rule("filter", "FORWARD", "-d", "10.4.0.0/16", "-j", "ACCEPT"), } +var prependRules = []Rule{ + NewIPv4Rule("filter", "FORWARD", "-s", "10.5.0.0/16", "-j", "DROP"), + NewIPv4Rule("filter", "FORWARD", "-s", "10.6.0.0/16", "-j", "DROP"), +} + func TestSet(t *testing.T) { for _, tc := range []struct { - name string - sets [][]Rule - out []Rule - actions []func(Client) error + name string + sets []RuleSet + appendOut []Rule + prependOut []Rule + storageOut []Rule + actions []func(Client) error }{ { name: "empty", }, { name: "single", - sets: [][]Rule{ - {rules[0]}, + sets: []RuleSet{ + {appendRules: []Rule{appendRules[0]}}, }, - out: []Rule{rules[0]}, + appendOut: []Rule{appendRules[0]}, + storageOut: []Rule{appendRules[0]}, }, { name: "two rules", - sets: [][]Rule{ - {rules[0], rules[1]}, + sets: []RuleSet{ + {appendRules: []Rule{appendRules[0], appendRules[1]}}, }, - out: []Rule{rules[0], rules[1]}, + appendOut: []Rule{appendRules[0], appendRules[1]}, + storageOut: []Rule{appendRules[0], appendRules[1]}, }, { name: "multiple", - sets: [][]Rule{ - {rules[0], rules[1]}, - {rules[1]}, + sets: []RuleSet{ + {appendRules: []Rule{appendRules[0], appendRules[1]}}, + {appendRules: []Rule{appendRules[1]}}, }, - out: []Rule{rules[1]}, + appendOut: []Rule{appendRules[1]}, + storageOut: []Rule{appendRules[1]}, }, { name: "re-add", - sets: [][]Rule{ - {rules[0], rules[1]}, + sets: []RuleSet{ + {appendRules: []Rule{appendRules[0], appendRules[1]}}, }, - out: []Rule{rules[0], rules[1]}, + appendOut: []Rule{appendRules[0], appendRules[1]}, + storageOut: []Rule{appendRules[0], appendRules[1]}, actions: []func(c Client) error{ func(c Client) error { - return rules[0].Delete(c) + return appendRules[0].Delete(c) }, func(c Client) error { - return rules[1].Delete(c) + return appendRules[1].Delete(c) }, }, }, { name: "order", - sets: [][]Rule{ - {rules[0], rules[1]}, + sets: []RuleSet{ + {appendRules: []Rule{appendRules[0], appendRules[1]}}, }, - out: []Rule{rules[0], rules[1]}, + appendOut: []Rule{appendRules[0], appendRules[1]}, + storageOut: []Rule{appendRules[0], appendRules[1]}, actions: []func(c Client) error{ func(c Client) error { - return rules[0].Delete(c) + return appendRules[0].Delete(c) }, }, }, + { + name: "append and prepend", + sets: []RuleSet{ + { + prependRules: []Rule{prependRules[0], prependRules[1]}, + appendRules: []Rule{appendRules[0], appendRules[1]}, + }, + }, + appendOut: []Rule{appendRules[0], appendRules[1]}, + prependOut: []Rule{prependRules[0], prependRules[1]}, + storageOut: []Rule{prependRules[1], prependRules[0], appendRules[0], appendRules[1]}, + }, } { client := &fakeClient{} controller, err := New(WithClients(client, client)) @@ -90,7 +114,7 @@ func TestSet(t *testing.T) { } for i := range tc.sets { if err := controller.Set(tc.sets[i]); err != nil { - t.Fatalf("test case %q: got unexpected error seting rule set %d: %v", tc.name, i, err) + t.Fatalf("test case %q: got unexpected error setting rule set %d: %v", tc.name, i, err) } } for i, f := range tc.actions { @@ -101,21 +125,30 @@ func TestSet(t *testing.T) { if err := controller.reconcile(); err != nil { t.Fatalf("test case %q: got unexpected error %v", tc.name, err) } - if len(tc.out) != len(client.storage) { - t.Errorf("test case %q: expected %d rules in storage, got %d", tc.name, len(tc.out), len(client.storage)) + if len(tc.storageOut) != len(client.storage) { + t.Errorf("test case %q: expected %d rules in storage, got %d", tc.name, len(tc.storageOut), len(client.storage)) } else { - for i := range tc.out { - if tc.out[i].String() != client.storage[i].String() { - t.Errorf("test case %q: expected rule %d in storage to be equal: expected %v, got %v", tc.name, i, tc.out[i], client.storage[i]) + for i := range tc.storageOut { + if tc.storageOut[i].String() != client.storage[i].String() { + t.Errorf("test case %q: expected rule %d in storage to be equal: expected %v, got %v", tc.name, i, tc.storageOut[i], client.storage[i]) } } } - if len(tc.out) != len(controller.rules) { - t.Errorf("test case %q: expected %d rules in controller, got %d", tc.name, len(tc.out), len(controller.rules)) + if len(tc.appendOut) != len(controller.appendRules) { + t.Errorf("test case %q: expected %d appendRules in controller, got %d", tc.name, len(tc.appendOut), len(controller.appendRules)) } else { - for i := range tc.out { - if tc.out[i].String() != controller.rules[i].String() { - t.Errorf("test case %q: expected rule %d in controller to be equal: expected %v, got %v", tc.name, i, tc.out[i], controller.rules[i]) + for i := range tc.appendOut { + if tc.appendOut[i].String() != controller.appendRules[i].String() { + t.Errorf("test case %q: expected appendRule %d in controller to be equal: expected %v, got %v", tc.name, i, tc.appendOut[i], controller.appendRules[i]) + } + } + } + if len(tc.prependOut) != len(controller.prependRules) { + t.Errorf("test case %q: expected %d prependRules in controller, got %d", tc.name, len(tc.prependOut), len(controller.prependRules)) + } else { + for i := range tc.prependOut { + if tc.prependOut[i].String() != controller.prependRules[i].String() { + t.Errorf("test case %q: expected prependRule %d in controller to be equal: expected %v, got %v", tc.name, i, tc.prependOut[i], controller.prependRules[i]) } } } @@ -124,20 +157,26 @@ func TestSet(t *testing.T) { func TestCleanUp(t *testing.T) { for _, tc := range []struct { - name string - rules []Rule + name string + appendRules []Rule + prependRules []Rule }{ { - name: "empty", - rules: nil, + name: "empty", + appendRules: nil, + }, + { + name: "single append", + appendRules: []Rule{appendRules[0]}, }, { - name: "single", - rules: []Rule{rules[0]}, + name: "multiple append", + appendRules: []Rule{appendRules[0], appendRules[1]}, }, { - name: "multiple", - rules: []Rule{rules[0], rules[1]}, + name: "multiple append and prepend", + appendRules: []Rule{appendRules[0], appendRules[1]}, + prependRules: []Rule{prependRules[0], prependRules[1]}, }, } { client := &fakeClient{} @@ -145,11 +184,12 @@ func TestCleanUp(t *testing.T) { if err != nil { t.Fatalf("test case %q: got unexpected error instantiating controller: %v", tc.name, err) } - if err := controller.Set(tc.rules); err != nil { + ruleSet := RuleSet{appendRules: tc.appendRules, prependRules: tc.prependRules} + if err := controller.Set(ruleSet); err != nil { t.Fatalf("test case %q: Set should not fail: %v", tc.name, err) } - if len(client.storage) != len(tc.rules) { - t.Errorf("test case %q: expected %d rules in storage, got %d rules", tc.name, len(tc.rules), len(client.storage)) + if len(client.storage) != len(tc.appendRules)+len(tc.prependRules) { + t.Errorf("test case %q: expected %d rules in storage, got %d rules", tc.name, len(ruleSet.appendRules)+len(ruleSet.prependRules), len(client.storage)) } if err := controller.CleanUp(); err != nil { t.Errorf("test case %q: got unexpected error: %v", tc.name, err) @@ -159,3 +199,42 @@ func TestCleanUp(t *testing.T) { } } } + +func TestReconcile(t *testing.T) { + for _, tc := range []struct { + name string + appendRules []Rule + prependRules []Rule + storageOut []Rule + }{ + { + name: "append and prepend rules", + appendRules: []Rule{appendRules[0], appendRules[1]}, + prependRules: []Rule{prependRules[0], prependRules[1]}, + storageOut: []Rule{prependRules[1], prependRules[0], appendRules[0], appendRules[1]}, + }, + } { + client := &fakeClient{} + controller, err := New(WithClients(client, client)) + if err != nil { + t.Fatalf("test case %q: got unexpected error instantiating controller: %v", tc.name, err) + } + controller.appendRules = tc.appendRules + controller.prependRules = tc.prependRules + + err = controller.reconcile() + if err != nil { + t.Fatalf("test case %q: unexpected error during reconcile: %v", tc.name, err) + } + + if len(tc.storageOut) != len(client.storage) { + t.Errorf("test case %q: expected %d rules in storage, got %d", tc.name, len(tc.storageOut), len(client.storage)) + } else { + for i := range tc.storageOut { + if tc.storageOut[i].String() != client.storage[i].String() { + t.Errorf("test case %q: expected rule %d in storage to be equal: expected %v, got %v", tc.name, i, tc.storageOut[i], client.storage[i]) + } + } + } + } +} diff --git a/pkg/iptables/metrics.go b/pkg/iptables/metrics.go index b262937d..dee77b72 100644 --- a/pkg/iptables/metrics.go +++ b/pkg/iptables/metrics.go @@ -51,6 +51,15 @@ func (m *metricsClientWrapper) AppendUnique(table string, chain string, rule ... return m.client.AppendUnique(table, chain, rule...) } +func (m *metricsClientWrapper) InsertUnique(table, chain string, pos int, rule ...string) error { + m.operationCounter.With(prometheus.Labels{ + "operation": "InsertUnique", + "table": table, + "chain": chain, + }).Inc() + return m.client.InsertUnique(table, chain, pos, rule...) +} + func (m *metricsClientWrapper) Delete(table string, chain string, rule ...string) error { m.operationCounter.With(prometheus.Labels{ "operation": "Delete", diff --git a/pkg/iptables/rulecache_test.go b/pkg/iptables/rulecache_test.go index f3f1ead7..3561f274 100644 --- a/pkg/iptables/rulecache_test.go +++ b/pkg/iptables/rulecache_test.go @@ -29,21 +29,21 @@ func TestRuleCache(t *testing.T) { { name: "empty", rules: nil, - check: []Rule{rules[0]}, + check: []Rule{appendRules[0]}, out: []bool{false}, calls: 1, }, { name: "single negative", - rules: []Rule{rules[1]}, - check: []Rule{rules[0]}, + rules: []Rule{appendRules[1]}, + check: []Rule{appendRules[0]}, out: []bool{false}, calls: 1, }, { name: "single positive", - rules: []Rule{rules[1]}, - check: []Rule{rules[1]}, + rules: []Rule{appendRules[1]}, + check: []Rule{appendRules[1]}, out: []bool{true}, calls: 1, }, @@ -56,29 +56,29 @@ func TestRuleCache(t *testing.T) { }, { name: "rule on chain means chain exists", - rules: []Rule{rules[0]}, - check: []Rule{rules[0], &chain{"filter", "FORWARD", ProtocolIPv4}}, + rules: []Rule{appendRules[0]}, + check: []Rule{appendRules[0], &chain{"filter", "FORWARD", ProtocolIPv4}}, out: []bool{true, true}, calls: 1, }, { name: "rule on chain does not mean table is fully populated", - rules: []Rule{rules[0], &chain{"filter", "INPUT", ProtocolIPv4}}, - check: []Rule{rules[0], &chain{"filter", "OUTPUT", ProtocolIPv4}, &chain{"filter", "INPUT", ProtocolIPv4}}, + rules: []Rule{appendRules[0], &chain{"filter", "INPUT", ProtocolIPv4}}, + check: []Rule{appendRules[0], &chain{"filter", "OUTPUT", ProtocolIPv4}, &chain{"filter", "INPUT", ProtocolIPv4}}, out: []bool{true, false, true}, calls: 2, }, { name: "multiple rules on chain", - rules: []Rule{rules[0], rules[1]}, - check: []Rule{rules[0], rules[1], &chain{"filter", "FORWARD", ProtocolIPv4}}, + rules: []Rule{appendRules[0], appendRules[1]}, + check: []Rule{appendRules[0], appendRules[1], &chain{"filter", "FORWARD", ProtocolIPv4}}, out: []bool{true, true, true}, calls: 1, }, { name: "checking rule on chain does not mean chain exists", rules: nil, - check: []Rule{rules[0], &chain{"filter", "FORWARD", ProtocolIPv4}}, + check: []Rule{appendRules[0], &chain{"filter", "FORWARD", ProtocolIPv4}}, out: []bool{false, false}, calls: 2, }, @@ -101,7 +101,8 @@ func TestRuleCache(t *testing.T) { client := &fakeClient{} controller.v4 = client controller.v6 = client - if err := controller.Set(tc.rules); err != nil { + ruleSet := RuleSet{appendRules: tc.rules} + if err := controller.Set(ruleSet); err != nil { t.Fatalf("test case %q: Set should not fail: %v", tc.name, err) } // Reset the client's calls so we can examine how many times diff --git a/pkg/mesh/mesh.go b/pkg/mesh/mesh.go index a7c70584..19304e16 100644 --- a/pkg/mesh/mesh.go +++ b/pkg/mesh/mesh.go @@ -526,7 +526,8 @@ func (m *Mesh) applyTopology() { } } - ipRules = append(m.enc.Rules(cidrs), ipRules...) + encIpRules := m.enc.Rules(cidrs) + ipRules = encIpRules.AppendRuleSet(ipRules) // If we are handling local routes, ensure the local // tunnel has an IP address. diff --git a/pkg/mesh/routes.go b/pkg/mesh/routes.go index 7ede6ab1..bff920ea 100644 --- a/pkg/mesh/routes.go +++ b/pkg/mesh/routes.go @@ -311,12 +311,12 @@ func encapsulateRoute(route *netlink.Route, encapsulate encapsulation.Strategy, } // Rules returns the iptables rules required by the local node. -func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule { - var rules []iptables.Rule - rules = append(rules, iptables.NewIPv4Chain("nat", "KILO-NAT")) - rules = append(rules, iptables.NewIPv6Chain("nat", "KILO-NAT")) +func (t *Topology) Rules(cni, iptablesForwardRule bool) iptables.RuleSet { + rules := iptables.RuleSet{} + rules.AddToAppend(iptables.NewIPv4Chain("nat", "KILO-NAT")) + rules.AddToAppend(iptables.NewIPv6Chain("nat", "KILO-NAT")) if cni { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "nat", "POSTROUTING", "-s", t.subnet.String(), "-m", "comment", "--comment", "Kilo: jump to KILO-NAT chain", "-j", "KILO-NAT")) // Some linux distros or docker will set forward DROP in the filter table. // To still be able to have pod to pod communication we need to ALLOW packets from and to pod CIDRs within a location. // Leader nodes will forward packets from all nodes within a location because they act as a gateway for them. @@ -326,55 +326,51 @@ func (t *Topology) Rules(cni, iptablesForwardRule bool) []iptables.Rule { if s.location == t.location { // Make sure packets to and from pod cidrs are not dropped in the forward chain. for _, c := range s.cidrs { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT")) - rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the pod subnet", "-s", c.String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the pod subnet", "-d", c.String(), "-j", "ACCEPT")) } // Make sure packets to and from allowed location IPs are not dropped in the forward chain. for _, c := range s.allowedLocationIPs { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT")) - rules = append(rules, iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from allowed location IPs", "-s", c.String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to allowed location IPs", "-d", c.String(), "-j", "ACCEPT")) } // Make sure packets to and from private IPs are not dropped in the forward chain. for _, c := range s.privateIPs { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT")) - rules = append(rules, iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from private IPs", "-s", oneAddressCIDR(c).String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(c), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to private IPs", "-d", oneAddressCIDR(c).String(), "-j", "ACCEPT")) } } } } else if iptablesForwardRule { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT")) - rules = append(rules, iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets from the node's pod subnet", "-s", t.subnet.String(), "-j", "ACCEPT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(t.subnet.IP), "filter", "FORWARD", "-m", "comment", "--comment", "Kilo: forward packets to the node's pod subnet", "-d", t.subnet.String(), "-j", "ACCEPT")) } } for _, s := range t.segments { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(s.wireGuardIP), "nat", "KILO-NAT", "-d", oneAddressCIDR(s.wireGuardIP).String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for WireGuared IPs", "-j", "RETURN")) for _, aip := range s.allowedIPs { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for known IPs", "-j", "RETURN")) } // Make sure packets to allowed location IPs go through the KILO-NAT chain, so they can be MASQUERADEd, // Otherwise packets to these destinations will reach the destination, but never find their way back. // We only want to NAT in locations of the corresponding allowed location IPs. if t.location == s.location { for _, alip := range s.allowedLocationIPs { - rules = append(rules, - iptables.NewRule(iptables.GetProtocol(alip.IP), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"), - ) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(alip.IP), "nat", "POSTROUTING", "-d", alip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT")) } } } for _, p := range t.peers { for _, aip := range p.AllowedIPs { - rules = append(rules, - iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT"), - iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN"), - ) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "POSTROUTING", "-s", aip.String(), "-m", "comment", "--comment", "Kilo: jump to NAT chain", "-j", "KILO-NAT")) + rules.AddToPrepend(iptables.NewRule(iptables.GetProtocol(aip.IP), "nat", "KILO-NAT", "-d", aip.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for peers", "-j", "RETURN")) } } for _, s := range t.serviceCIDRs { - rules = append(rules, iptables.NewRule(iptables.GetProtocol(s.IP), "nat", "KILO-NAT", "-d", s.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for service CIDRs", "-j", "RETURN")) + rules.AddToAppend(iptables.NewRule(iptables.GetProtocol(s.IP), "nat", "KILO-NAT", "-d", s.String(), "-m", "comment", "--comment", "Kilo: do not NAT packets destined for service CIDRs", "-j", "RETURN")) } - rules = append(rules, iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE")) - rules = append(rules, iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE")) + rules.AddToAppend(iptables.NewIPv4Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE")) + rules.AddToAppend(iptables.NewIPv6Rule("nat", "KILO-NAT", "-m", "comment", "--comment", "Kilo: NAT remaining packets", "-j", "MASQUERADE")) return rules } diff --git a/vendor/github.com/coreos/go-iptables/iptables/iptables.go b/vendor/github.com/coreos/go-iptables/iptables/iptables.go index 85047e59..1e7ad245 100644 --- a/vendor/github.com/coreos/go-iptables/iptables/iptables.go +++ b/vendor/github.com/coreos/go-iptables/iptables/iptables.go @@ -109,6 +109,7 @@ func Timeout(timeout int) option { // For backwards compatibility, by default always uses IPv4 and timeout 0. // i.e. you can create an IPv6 IPTables using a timeout of 5 seconds passing // the IPFamily and Timeout options as follow: +// // ip6t := New(IPFamily(ProtocolIPv6), Timeout(5)) func New(opts ...option) (*IPTables, error) { @@ -185,6 +186,20 @@ func (ipt *IPTables) Insert(table, chain string, pos int, rulespec ...string) er return ipt.run(cmd...) } +// InsertUnique acts like Insert except that it won't insert a duplicate (no matter the position in the chain) +func (ipt *IPTables) InsertUnique(table, chain string, pos int, rulespec ...string) error { + exists, err := ipt.Exists(table, chain, rulespec...) + if err != nil { + return err + } + + if !exists { + return ipt.Insert(table, chain, pos, rulespec...) + } + + return nil +} + // Append appends rulespec to specified table/chain func (ipt *IPTables) Append(table, chain string, rulespec ...string) error { cmd := append([]string{"-t", table, "-A", chain}, rulespec...) @@ -219,6 +234,16 @@ func (ipt *IPTables) DeleteIfExists(table, chain string, rulespec ...string) err return err } +// List rules in specified table/chain +func (ipt *IPTables) ListById(table, chain string, id int) (string, error) { + args := []string{"-t", table, "-S", chain, strconv.Itoa(id)} + rule, err := ipt.executeList(args) + if err != nil { + return "", err + } + return rule[0], nil +} + // List rules in specified table/chain func (ipt *IPTables) List(table, chain string) ([]string, error) { args := []string{"-t", table, "-S", chain} @@ -510,7 +535,9 @@ func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error { syscall.Close(fmu.fd) return err } - defer ul.Unlock() + defer func() { + _ = ul.Unlock() + }() } var stderr bytes.Buffer @@ -619,7 +646,7 @@ func iptablesHasWaitCommand(v1 int, v2 int, v3 int) bool { return false } -//Checks if an iptablse version is after 1.6.0, when --wait support second +// Checks if an iptablse version is after 1.6.0, when --wait support second func iptablesWaitSupportSecond(v1 int, v2 int, v3 int) bool { if v1 > 1 { return true diff --git a/vendor/modules.txt b/vendor/modules.txt index 03d222ae..ac67d78b 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -39,7 +39,7 @@ github.com/containernetworking/plugins/pkg/ns github.com/containernetworking/plugins/pkg/utils/sysctl github.com/containernetworking/plugins/plugins/ipam/host-local/backend github.com/containernetworking/plugins/plugins/ipam/host-local/backend/allocator -# github.com/coreos/go-iptables v0.6.0 +# github.com/coreos/go-iptables v0.6.1-0.20220901214115-d2b8608923d1 ## explicit; go 1.16 github.com/coreos/go-iptables/iptables # github.com/davecgh/go-spew v1.1.1