Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set the rule handle after flush #88

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
23 changes: 21 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type Conn struct {
NetNS int // Network namespace netlink will interact with.
sync.Mutex
messages []netlink.Message
rules []*Rule
alexispires marked this conversation as resolved.
Show resolved Hide resolved
err error
}

Expand All @@ -43,6 +44,7 @@ func (cc *Conn) Flush() error {
cc.Lock()
defer func() {
cc.messages = nil
cc.rules = nil
cc.Unlock()
}()
if len(cc.messages) == 0 {
Expand All @@ -63,8 +65,25 @@ func (cc *Conn) Flush() error {
return fmt.Errorf("SendMessages: %w", err)
}

if _, err := conn.Receive(); err != nil {
return fmt.Errorf("Receive: %w", err)
echoedRules := 0

for len(cc.rules) > echoedRules {
rmsg, err := conn.Receive()

alexispires marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return fmt.Errorf("Receive: %w", err)
}

for _, msg := range rmsg {
if msg.Header.Type == ruleHeaderType {
rule, err := ruleFromMsg(msg)
if err == nil {
cc.rules[echoedRules].Handle = rule.Handle
echoedRules++
}
}
}

}

return nil
Expand Down
74 changes: 74 additions & 0 deletions nftables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3980,3 +3980,77 @@ func TestStatelessNAT(t *testing.T) {
t.Fatal(err)
}
}

func TestHandleBack(t *testing.T) {

// Create a new network namespace to test these operations,
// and tear down the namespace at test completion.
c, newNS := openSystemNFTConn(t)
defer cleanupSystemNFTConn(t, newNS)
// Clear all rules at the beginning + end of the test.
c.FlushRuleset()
defer c.FlushRuleset()

filter := c.AddTable(&nftables.Table{
Family: nftables.TableFamilyIPv4,
Name: "filter",
})

prerouting := c.AddChain(&nftables.Chain{
Name: "base-chain",
Table: filter,
Type: nftables.ChainTypeFilter,
Hooknum: nftables.ChainHookPrerouting,
Priority: nftables.ChainPriorityFilter,
})

var rulesCreated []*nftables.Rule

rulesCreated = append(rulesCreated, c.AddRule(&nftables.Rule{
Table: filter,
Chain: prerouting,
Exprs: []expr.Any{
&expr.Verdict{
// [ immediate reg 0 drop ]
Kind: expr.VerdictDrop,
},
},
}))

rulesCreated = append(rulesCreated, c.AddRule(&nftables.Rule{
Table: filter,
Chain: prerouting,
Exprs: []expr.Any{
&expr.Verdict{
// [ immediate reg 0 drop ]
Kind: expr.VerdictDrop,
},
},
}))

for i, r := range rulesCreated {
if r.Handle != 0 {
t.Fatalf("unexpected handle value at %d", i)
}
}

if err := c.Flush(); err != nil {
t.Fatal(err)
}

rulesGetted, _ := c.GetRule(filter, prerouting)

if len(rulesGetted) != len(rulesCreated) {
t.Fatalf("Bad ruleset lenght got %d want %d", len(rulesGetted), len(rulesCreated))
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this test running many workers concurrently? In other words: why is not sufficient to test with 1 worker?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To test the behaviors with concurrency as exprimed here: #88 (comment)
IMO it's safer to keep it to identify regression on concurent access. But it's not specific on this part of lib, I think concurrency have to be tested on the whole lib.

for i, r := range rulesGetted {
if r.Handle == 0 {
t.Fatalf("handle value is empty at %d", i)
}

if r.Handle != rulesCreated[i].Handle {
t.Fatalf("mismatched handle at %d", i)
}
}
}
2 changes: 2 additions & 0 deletions rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ func (cc *Conn) AddRule(r *Rule) *Rule {
Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...),
})

cc.rules = append(cc.rules, r)

return r
}

Expand Down