Skip to content

Commit

Permalink
fix: sort ports and merge adjacent ones in the nft rule
Browse files Browse the repository at this point in the history
Fixes #9009

When building a port interval set, sort the ports and merge adjacent
ranges to prevent mismatch on the nftables side.

With address sets, this was already the case due to the way IPRange
builder works, but ports need a manual implementation.

Signed-off-by: Andrey Smirnov <andrey.smirnov@siderolabs.com>
  • Loading branch information
smira committed Jul 12, 2024
1 parent cf5effa commit f14c479
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 9 deletions.
27 changes: 25 additions & 2 deletions internal/app/machined/pkg/adapters/network/nftables_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package network

import (
"cmp"
"fmt"
"net/netip"
"os"
Expand Down Expand Up @@ -109,9 +110,11 @@ func (set NfTablesSet) SetElements() []nftables.SetElement {

return elements
case SetKindPort:
elements := make([]nftables.SetElement, 0, len(set.Ports))
ports := mergeAdjacentPorts(set.Ports)

for _, p := range set.Ports {
elements := make([]nftables.SetElement, 0, len(ports))

for _, p := range ports {
from := binaryutil.BigEndian.PutUint16(p[0])
to := binaryutil.BigEndian.PutUint16(p[1] + 1)

Expand Down Expand Up @@ -157,6 +160,26 @@ func (set NfTablesSet) SetElements() []nftables.SetElement {
}
}

func mergeAdjacentPorts(in [][2]uint16) [][2]uint16 {
ports := slices.Clone(in)

slices.SortFunc(ports, func(a, b [2]uint16) int {
// sort by the lower bound of the range, assume no overlap
return cmp.Compare(a[0], b[0])
})

for i := 0; i < len(ports)-1; {
if ports[i][1]+1 >= ports[i+1][0] {
ports[i][1] = ports[i+1][1]
ports = append(ports[:i+1], ports[i+2:]...)
} else {
i++
}
}

return ports
}

// NfTablesCompiled is a compiled representation of the rule.
type NfTablesCompiled struct {
Rules [][]expr.Any
Expand Down
55 changes: 50 additions & 5 deletions internal/app/machined/pkg/adapters/network/nftables_rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,14 +526,14 @@ func TestNfTablesRuleCompile(t *testing.T) { //nolint:tparallel
Protocol: nethelpers.ProtocolTCP,
MatchSourcePort: &networkres.NfTablesPortMatch{
Ranges: []networkres.PortRange{
{
Lo: 1000,
Hi: 1025,
},
{
Lo: 2000,
Hi: 2000,
},
{
Lo: 1000,
Hi: 1025,
},
},
},
},
Expand Down Expand Up @@ -562,8 +562,8 @@ func TestNfTablesRuleCompile(t *testing.T) { //nolint:tparallel
{
Kind: network.SetKindPort,
Ports: [][2]uint16{
{1000, 1025},
{2000, 2000},
{1000, 1025},
},
},
},
Expand Down Expand Up @@ -713,3 +713,48 @@ func TestNfTablesRuleCompile(t *testing.T) { //nolint:tparallel
})
}
}

func TestNftablesSet(t *testing.T) { //nolint:tparallel
t.Parallel()

for _, test := range []struct {
name string

set network.NfTablesSet

expectedKeyType nftables.SetDatatype
expectedInterval bool
expectedData []nftables.SetElement
}{
{
name: "ports",

set: network.NfTablesSet{
Kind: network.SetKindPort,
Ports: [][2]uint16{
{443, 443},
{80, 81},
{5000, 5000},
{5001, 5001},
},
},

expectedKeyType: nftables.TypeInetService,
expectedInterval: true,
expectedData: []nftables.SetElement{ // network byte order
{Key: []uint8{0x0, 80}, IntervalEnd: false}, // 80 - 81
{Key: []uint8{0x0, 82}, IntervalEnd: true},
{Key: []uint8{0x1, 0xbb}, IntervalEnd: false}, // 443-443
{Key: []uint8{0x1, 0xbc}, IntervalEnd: true},
{Key: []uint8{0x13, 0x88}, IntervalEnd: false}, // 5000-5001
{Key: []uint8{0x13, 0x8a}, IntervalEnd: true},
},
},
} {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.expectedKeyType, test.set.KeyType())
assert.Equal(t, test.expectedInterval, test.set.IsInterval())
assert.Equal(t, test.expectedData, test.set.SetElements())
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,60 @@ func (s *NfTablesChainSuite) TestL4Match2() {
s.checkNftOutput(`table inet talos-test {
chain test-tcp {
type filter hook input priority filter; policy accept;
ip saddr != { 10.0.0.0/8 } tcp dport { 1023, 1024 } drop
meta nfproto ipv6 tcp dport { 1023, 1024 } drop
ip saddr != { 10.0.0.0/8 } tcp dport { 1023-1024 } drop
meta nfproto ipv6 tcp dport { 1023-1024 } drop
}
}`)
}

func (s *NfTablesChainSuite) TestL4MatchAdjacentPorts() {
chain := network.NewNfTablesChain(network.NamespaceName, "test-tcp")
chain.TypedSpec().Type = nethelpers.ChainTypeFilter
chain.TypedSpec().Hook = nethelpers.ChainHookInput
chain.TypedSpec().Priority = nethelpers.ChainPriorityFilter
chain.TypedSpec().Policy = nethelpers.VerdictAccept
chain.TypedSpec().Rules = []network.NfTablesRule{
{
MatchSourceAddress: &network.NfTablesAddressMatch{
IncludeSubnets: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
},
Invert: true,
},
MatchLayer4: &network.NfTablesLayer4Match{
Protocol: nethelpers.ProtocolTCP,
MatchDestinationPort: &network.NfTablesPortMatch{
Ranges: []network.PortRange{
{
Lo: 5000,
Hi: 5000,
},
{
Lo: 5001,
Hi: 5001,
},
{
Lo: 10250,
Hi: 10250,
},
{
Lo: 4240,
Hi: 4240,
},
},
},
},
Verdict: pointer.To(nethelpers.VerdictDrop),
},
}

s.Require().NoError(s.State().Create(s.Ctx(), chain))

s.checkNftOutput(`table inet talos-test {
chain test-tcp {
type filter hook input priority filter; policy accept;
ip saddr != { 10.0.0.0/8 } tcp dport { 4240, 5000-5001, 10250 } drop
meta nfproto ipv6 tcp dport { 4240, 5000-5001, 10250 } drop
}
}`)
}
Expand Down

0 comments on commit f14c479

Please sign in to comment.