From f14c4795e5e60bf564d584a707e261bed78bcaf8 Mon Sep 17 00:00:00 2001 From: Andrey Smirnov Date: Fri, 12 Jul 2024 14:53:23 +0400 Subject: [PATCH] fix: sort ports and merge adjacent ones in the nft rule Fixes #9009 When building a port interval set, sort the ports and merge adjacent ranges to prevent mismatch on the nftables side. With address sets, this was already the case due to the way IPRange builder works, but ports need a manual implementation. Signed-off-by: Andrey Smirnov --- .../pkg/adapters/network/nftables_rule.go | 27 ++++++++- .../adapters/network/nftables_rule_test.go | 55 ++++++++++++++++-- .../network/nftables_chain_test.go | 56 ++++++++++++++++++- 3 files changed, 129 insertions(+), 9 deletions(-) diff --git a/internal/app/machined/pkg/adapters/network/nftables_rule.go b/internal/app/machined/pkg/adapters/network/nftables_rule.go index 618b8d1dae..c79d754e28 100644 --- a/internal/app/machined/pkg/adapters/network/nftables_rule.go +++ b/internal/app/machined/pkg/adapters/network/nftables_rule.go @@ -5,6 +5,7 @@ package network import ( + "cmp" "fmt" "net/netip" "os" @@ -109,9 +110,11 @@ func (set NfTablesSet) SetElements() []nftables.SetElement { return elements case SetKindPort: - elements := make([]nftables.SetElement, 0, len(set.Ports)) + ports := mergeAdjacentPorts(set.Ports) - for _, p := range set.Ports { + elements := make([]nftables.SetElement, 0, len(ports)) + + for _, p := range ports { from := binaryutil.BigEndian.PutUint16(p[0]) to := binaryutil.BigEndian.PutUint16(p[1] + 1) @@ -157,6 +160,26 @@ func (set NfTablesSet) SetElements() []nftables.SetElement { } } +func mergeAdjacentPorts(in [][2]uint16) [][2]uint16 { + ports := slices.Clone(in) + + slices.SortFunc(ports, func(a, b [2]uint16) int { + // sort by the lower bound of the range, assume no overlap + return cmp.Compare(a[0], b[0]) + }) + + for i := 0; i < len(ports)-1; { + if ports[i][1]+1 >= ports[i+1][0] { + ports[i][1] = ports[i+1][1] + ports = append(ports[:i+1], ports[i+2:]...) + } else { + i++ + } + } + + return ports +} + // NfTablesCompiled is a compiled representation of the rule. type NfTablesCompiled struct { Rules [][]expr.Any diff --git a/internal/app/machined/pkg/adapters/network/nftables_rule_test.go b/internal/app/machined/pkg/adapters/network/nftables_rule_test.go index a898082f61..2cc1ec890e 100644 --- a/internal/app/machined/pkg/adapters/network/nftables_rule_test.go +++ b/internal/app/machined/pkg/adapters/network/nftables_rule_test.go @@ -526,14 +526,14 @@ func TestNfTablesRuleCompile(t *testing.T) { //nolint:tparallel Protocol: nethelpers.ProtocolTCP, MatchSourcePort: &networkres.NfTablesPortMatch{ Ranges: []networkres.PortRange{ - { - Lo: 1000, - Hi: 1025, - }, { Lo: 2000, Hi: 2000, }, + { + Lo: 1000, + Hi: 1025, + }, }, }, }, @@ -562,8 +562,8 @@ func TestNfTablesRuleCompile(t *testing.T) { //nolint:tparallel { Kind: network.SetKindPort, Ports: [][2]uint16{ - {1000, 1025}, {2000, 2000}, + {1000, 1025}, }, }, }, @@ -713,3 +713,48 @@ func TestNfTablesRuleCompile(t *testing.T) { //nolint:tparallel }) } } + +func TestNftablesSet(t *testing.T) { //nolint:tparallel + t.Parallel() + + for _, test := range []struct { + name string + + set network.NfTablesSet + + expectedKeyType nftables.SetDatatype + expectedInterval bool + expectedData []nftables.SetElement + }{ + { + name: "ports", + + set: network.NfTablesSet{ + Kind: network.SetKindPort, + Ports: [][2]uint16{ + {443, 443}, + {80, 81}, + {5000, 5000}, + {5001, 5001}, + }, + }, + + expectedKeyType: nftables.TypeInetService, + expectedInterval: true, + expectedData: []nftables.SetElement{ // network byte order + {Key: []uint8{0x0, 80}, IntervalEnd: false}, // 80 - 81 + {Key: []uint8{0x0, 82}, IntervalEnd: true}, + {Key: []uint8{0x1, 0xbb}, IntervalEnd: false}, // 443-443 + {Key: []uint8{0x1, 0xbc}, IntervalEnd: true}, + {Key: []uint8{0x13, 0x88}, IntervalEnd: false}, // 5000-5001 + {Key: []uint8{0x13, 0x8a}, IntervalEnd: true}, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + assert.Equal(t, test.expectedKeyType, test.set.KeyType()) + assert.Equal(t, test.expectedInterval, test.set.IsInterval()) + assert.Equal(t, test.expectedData, test.set.SetElements()) + }) + } +} diff --git a/internal/app/machined/pkg/controllers/network/nftables_chain_test.go b/internal/app/machined/pkg/controllers/network/nftables_chain_test.go index 8d87c9b8ab..d960922467 100644 --- a/internal/app/machined/pkg/controllers/network/nftables_chain_test.go +++ b/internal/app/machined/pkg/controllers/network/nftables_chain_test.go @@ -442,8 +442,60 @@ func (s *NfTablesChainSuite) TestL4Match2() { s.checkNftOutput(`table inet talos-test { chain test-tcp { type filter hook input priority filter; policy accept; - ip saddr != { 10.0.0.0/8 } tcp dport { 1023, 1024 } drop - meta nfproto ipv6 tcp dport { 1023, 1024 } drop + ip saddr != { 10.0.0.0/8 } tcp dport { 1023-1024 } drop + meta nfproto ipv6 tcp dport { 1023-1024 } drop + } +}`) +} + +func (s *NfTablesChainSuite) TestL4MatchAdjacentPorts() { + chain := network.NewNfTablesChain(network.NamespaceName, "test-tcp") + chain.TypedSpec().Type = nethelpers.ChainTypeFilter + chain.TypedSpec().Hook = nethelpers.ChainHookInput + chain.TypedSpec().Priority = nethelpers.ChainPriorityFilter + chain.TypedSpec().Policy = nethelpers.VerdictAccept + chain.TypedSpec().Rules = []network.NfTablesRule{ + { + MatchSourceAddress: &network.NfTablesAddressMatch{ + IncludeSubnets: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/8"), + }, + Invert: true, + }, + MatchLayer4: &network.NfTablesLayer4Match{ + Protocol: nethelpers.ProtocolTCP, + MatchDestinationPort: &network.NfTablesPortMatch{ + Ranges: []network.PortRange{ + { + Lo: 5000, + Hi: 5000, + }, + { + Lo: 5001, + Hi: 5001, + }, + { + Lo: 10250, + Hi: 10250, + }, + { + Lo: 4240, + Hi: 4240, + }, + }, + }, + }, + Verdict: pointer.To(nethelpers.VerdictDrop), + }, + } + + s.Require().NoError(s.State().Create(s.Ctx(), chain)) + + s.checkNftOutput(`table inet talos-test { + chain test-tcp { + type filter hook input priority filter; policy accept; + ip saddr != { 10.0.0.0/8 } tcp dport { 4240, 5000-5001, 10250 } drop + meta nfproto ipv6 tcp dport { 4240, 5000-5001, 10250 } drop } }`) }