diff --git a/chain.go b/chain.go index 74caca5..48ebadf 100644 --- a/chain.go +++ b/chain.go @@ -123,7 +123,7 @@ func (cc *Conn) AddChain(c *Chain) *Chain { {Type: unix.NFTA_CHAIN_TYPE, Data: []byte(c.Type + "\x00")}, })...) } - cc.messages = append(cc.messages, netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWCHAIN), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -144,7 +144,7 @@ func (cc *Conn) DelChain(c *Chain) { {Type: unix.NFTA_CHAIN_NAME, Data: []byte(c.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELCHAIN), Flags: netlink.Request | netlink.Acknowledge, @@ -162,7 +162,7 @@ func (cc *Conn) FlushChain(c *Chain) { {Type: unix.NFTA_RULE_TABLE, Data: []byte(c.Table.Name + "\x00")}, {Type: unix.NFTA_RULE_CHAIN, Data: []byte(c.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: netlink.Request | netlink.Acknowledge, diff --git a/conn.go b/conn.go index 3768645..533336d 100644 --- a/conn.go +++ b/conn.go @@ -24,6 +24,10 @@ import ( "golang.org/x/sys/unix" ) +type Entity interface { + HandleResponse(netlink.Message) +} + // A Conn represents a netlink connection of the nftables family. // // All methods return their input, so that variables can be defined from string @@ -31,11 +35,13 @@ import ( // // Commands are buffered. Flush sends all buffered commands in a single batch. type Conn struct { - TestDial nltest.Func // for testing only; passed to nltest.Dial - NetNS int // Network namespace netlink will interact with. sync.Mutex - messages []netlink.Message - err error + TestDial nltest.Func // for testing only; passed to nltest.Dial + NetNS int // Network namespace netlink will interact with. + entities map[int]Entity + messagesMu sync.Mutex + messages []netlink.Message + err error } // Flush sends all buffered commands in a single batch to nftables. @@ -43,6 +49,7 @@ func (cc *Conn) Flush() error { cc.Lock() defer func() { cc.messages = nil + cc.entities = nil cc.Unlock() }() if len(cc.messages) == 0 { @@ -59,15 +66,99 @@ func (cc *Conn) Flush() error { defer conn.Close() - if _, err := conn.SendMessages(batch(cc.messages)); err != nil { + cc.endBatch(cc.messages) + + if _, err = conn.SendMessages(cc.messages); err != nil { return fmt.Errorf("SendMessages: %w", err) } - if _, err := conn.Receive(); err != nil { - return fmt.Errorf("Receive: %w", err) + // Retrieving of seq number associated to entities + entitiesBySeq := make(map[uint32]Entity) + for i, e := range cc.entities { + entitiesBySeq[cc.messages[i].Header.Sequence] = e + } + + // Trigger entities callback + msg, err := cc.checkReceive(conn) + if err != nil { + return err + } + + for msg { + rmsg, err := conn.Receive() + if err != nil { + return fmt.Errorf("Receive: %w", err) + } + + for _, msg := range rmsg { + if e, ok := entitiesBySeq[msg.Header.Sequence]; ok { + e.HandleResponse(msg) + + } + } + msg, err = cc.checkReceive(conn) + if err != nil { + return err + } + } + + return err +} + +// putMessage store netlink message to sent after +func (cc *Conn) putMessage(msg netlink.Message) int { + cc.messagesMu.Lock() + defer cc.messagesMu.Unlock() + + if cc.messages == nil { + cc.messages = append(cc.messages, netlink.Message{ + Header: netlink.Header{ + Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), + Flags: netlink.Request, + }, + Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), + }) + } + + cc.messages = append(cc.messages, msg) + + return len(cc.messages) - 1 +} + +// PutEntity store entity to relate to netlink response +func (cc *Conn) PutEntity(i int, e Entity) { + if cc.entities == nil { + cc.entities = make(map[int]Entity) + } + cc.entities[i] = e +} + +func (cc *Conn) checkReceive(c *netlink.Conn) (bool, error) { + if cc.TestDial != nil { + return false, nil + } + + sc, err := c.SyscallConn() + + if err != nil { + return false, fmt.Errorf("SyscallConn error: %w", err) + } + + var n int + + sc.Control(func(fd uintptr) { + var fdSet unix.FdSet + fdSet.Zero() + fdSet.Set(int(fd)) + + n, err = unix.Select(int(fd)+1, &fdSet, nil, nil, &unix.Timeval{}) + }) + + if err == nil && n > 0 { + return true, nil } - return nil + return false, err } // FlushRuleset flushes the entire ruleset. See also @@ -75,7 +166,7 @@ func (cc *Conn) Flush() error { func (cc *Conn) FlushRuleset() { cc.Lock() defer cc.Unlock() - cc.messages = append(cc.messages, netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -116,26 +207,16 @@ func (cc *Conn) marshalExpr(e expr.Any) []byte { return b } -func batch(messages []netlink.Message) []netlink.Message { - batch := []netlink.Message{ - { - Header: netlink.Header{ - Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_BEGIN), - Flags: netlink.Request, - }, - Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), - }, - } +func (cc *Conn) endBatch(messages []netlink.Message) { - batch = append(batch, messages...) + cc.messagesMu.Lock() + defer cc.messagesMu.Unlock() - batch = append(batch, netlink.Message{ + cc.messages = append(cc.messages, netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType(unix.NFNL_MSG_BATCH_END), Flags: netlink.Request, }, Data: extraHeader(0, unix.NFNL_SUBSYS_NFTABLES), }) - - return batch } diff --git a/go.mod b/go.mod index dfd5143..c5bf29e 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,8 @@ go 1.12 require ( github.com/koneu/natend v0.0.0-20150829182554-ec0926ea948d - github.com/mdlayher/netlink v0.0.0-20191009155606-de872b0d824b + github.com/mdlayher/netlink v1.0.0 github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc golang.org/x/net v0.0.0-20191028085509-fe3aa8a45271 // indirect - golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c + golang.org/x/sys v0.0.0-20200106114638-5f8ca72cd632 ) diff --git a/go.sum b/go.sum index 452fd2b..d20b949 100644 --- a/go.sum +++ b/go.sum @@ -8,6 +8,8 @@ github.com/koneu/natend v0.0.0-20150829182554-ec0926ea948d/go.mod h1:QHb4k4cr1fQ github.com/mdlayher/netlink v0.0.0-20190409211403-11939a169225/go.mod h1:eQB3mZE4aiYnlUsyGGCOpPETfdQq4Jhsgf1fk3cwQaA= github.com/mdlayher/netlink v0.0.0-20191009155606-de872b0d824b h1:W3er9pI7mt2gOqOWzwvx20iJ8Akiqz1mUMTxU6wdvl8= github.com/mdlayher/netlink v0.0.0-20191009155606-de872b0d824b/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M= +github.com/mdlayher/netlink v1.0.0 h1:vySPY5Oxnn/8lxAPn2cK6kAzcZzYJl3KriSLO46OT18= +github.com/mdlayher/netlink v1.0.0/go.mod h1:KxeJAFOFLG6AjpyDkQ/iIhxygIUKD+vcwqcnu43w/+M= github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc h1:R83G5ikgLMxrBvLh22JhdfI8K6YXEPHx5P03Uu3DRs4= github.com/vishvananda/netns v0.0.0-20180720170159-13995c7128cc/go.mod h1:ZjcWmFBXmLKZu9Nxj3WKYEafiSqer2rnvPr0en9UNpI= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -25,4 +27,6 @@ golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c h1:S/FtSvpNLtFBgjTqcKsRpsa6a golang.org/x/sys v0.0.0-20191029155521-f43be2a4598c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191112214154-59a1497f0cea h1:Mz1TMnfJDRJLk8S8OPCoJYgrsp/Se/2TBre2+vwX128= golang.org/x/sys v0.0.0-20191113150313-8ad342257130 h1:+sdNBpwFF05NvMnEyGynbOs/Gr2LQwORWEPKXuEXxzU= +golang.org/x/sys v0.0.0-20200106114638-5f8ca72cd632 h1:ateQkYCVYo8UwIBvoR3zj1Dh2K6Op/n3GxemXfB44/Y= +golang.org/x/sys v0.0.0-20200106114638-5f8ca72cd632/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/nftables_test.go b/nftables_test.go index e3337ef..bc966d6 100644 --- a/nftables_test.go +++ b/nftables_test.go @@ -23,6 +23,7 @@ import ( "reflect" "runtime" "strings" + "sync" "testing" "github.com/google/nftables" @@ -3980,3 +3981,84 @@ func TestStatelessNAT(t *testing.T) { t.Fatal(err) } } + +func TestIntegrationAddRule(t *testing.T) { + + // Create a new network namespace to test these operations, + // and tear down the namespace at test completion. + c, newNS := openSystemNFTConn(t) + defer cleanupSystemNFTConn(t, newNS) + // Clear all rules at the beginning + end of the test. + c.FlushRuleset() + defer c.FlushRuleset() + + filter := c.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "filter", + }) + + chain := c.AddChain(&nftables.Chain{ + Name: "chain", + Table: filter, + Type: nftables.ChainTypeFilter, + Hooknum: nftables.ChainHookPrerouting, + Priority: nftables.ChainPriorityFilter, + }) + + c.Flush() + + execN := func(w int, n int) { + c := &nftables.Conn{NetNS: int(newNS)} + + for i := 0; i < n; i++ { + + r := c.AddRule(&nftables.Rule{ + Table: filter, + Chain: chain, + UserData: []byte(fmt.Sprintf("%d-%d", w, i)), + Exprs: []expr.Any{ + &expr.Verdict{ + // [ immediate reg 0 drop ] + Kind: expr.VerdictDrop, + }, + }, + }) + + if r.Handle != 0 { + t.Fatalf("unexpected handle value at %d", i) + } + + if err := c.Flush(); err != nil { + t.Fatal(err) + } + + if r.Handle == 0 { + t.Fatalf("handle value is empty at %d", i) + } + + rulesGetted, _ := c.GetRule(filter, chain) + + for i, rg := range rulesGetted { + if bytes.Equal(rg.UserData, r.UserData) && rg.Handle != r.Handle { + t.Fatalf("mismatched handle at %d-%d, got: %d, want: %d", w, i, r.Handle, rg.Handle) + } + } + } + } + + const ( + workers = 16 + iterations = 256 + ) + + var wg sync.WaitGroup + wg.Add(workers) + for i := 0; i < workers; i++ { + go func(n int) { + defer wg.Done() + execN(n, iterations) + }(i) + } + + wg.Wait() +} diff --git a/obj.go b/obj.go index f3627df..99d51e0 100644 --- a/obj.go +++ b/obj.go @@ -43,7 +43,7 @@ func (cc *Conn) AddObj(o Obj) Obj { return nil } - cc.messages = append(cc.messages, netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWOBJ), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, diff --git a/rule.go b/rule.go index 48d79d1..d878b5e 100644 --- a/rule.go +++ b/rule.go @@ -122,13 +122,16 @@ func (cc *Conn) AddRule(r *Rule) *Rule { flags = netlink.Request | netlink.Acknowledge | netlink.Create | unix.NLM_F_ECHO | unix.NLM_F_APPEND } - cc.messages = append(cc.messages, netlink.Message{ + m := netlink.Message{ Header: netlink.Header{ Type: ruleHeaderType, Flags: flags, }, Data: append(extraHeader(uint8(r.Table.Family), 0), msgData...), - }) + } + + i := cc.putMessage(m) + cc.PutEntity(i, r) return r } @@ -149,7 +152,7 @@ func (cc *Conn) DelRule(r *Rule) error { })...) flags := netlink.Request | netlink.Acknowledge - cc.messages = append(cc.messages, netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: flags, @@ -160,6 +163,16 @@ func (cc *Conn) DelRule(r *Rule) error { return nil } +// HandleResponse retrieves Handle in netlink response +func (r *Rule) HandleResponse(msg netlink.Message) { + rule, err := ruleFromMsg(msg) + if err != nil { + return + } + + r.Handle = rule.Handle +} + func exprsFromMsg(b []byte) ([]expr.Any, error) { ad, err := netlink.NewAttributeDecoder(b) if err != nil { diff --git a/set.go b/set.go index 2b9ee7e..f45e0be 100644 --- a/set.go +++ b/set.go @@ -165,7 +165,7 @@ func (cc *Conn) SetAddElements(s *Set, vals []SetElement) error { if err != nil { return err } - cc.messages = append(cc.messages, netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSETELEM), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -327,7 +327,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { netlink.Attribute{Type: unix.NFTA_SET_USERDATA, Data: []byte("\x00\x04\x02\x00\x00\x00")}) } - cc.messages = append(cc.messages, netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWSET), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -342,7 +342,7 @@ func (cc *Conn) AddSet(s *Set, vals []SetElement) error { if err != nil { return err } - cc.messages = append(cc.messages, netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | hdrType), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -362,7 +362,7 @@ func (cc *Conn) DelSet(s *Set) { {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSET), Flags: netlink.Request | netlink.Acknowledge, @@ -383,7 +383,7 @@ func (cc *Conn) SetDeleteElements(s *Set, vals []SetElement) error { if err != nil { return err } - cc.messages = append(cc.messages, netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -402,7 +402,7 @@ func (cc *Conn) FlushSet(s *Set) { {Type: unix.NFTA_SET_TABLE, Data: []byte(s.Table.Name + "\x00")}, {Type: unix.NFTA_SET_NAME, Data: []byte(s.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELSETELEM), Flags: netlink.Request | netlink.Acknowledge, diff --git a/table.go b/table.go index da0126a..08c83f7 100644 --- a/table.go +++ b/table.go @@ -53,7 +53,7 @@ func (cc *Conn) DelTable(t *Table) { {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELTABLE), Flags: netlink.Request | netlink.Acknowledge, @@ -71,7 +71,7 @@ func (cc *Conn) AddTable(t *Table) *Table { {Type: unix.NFTA_TABLE_NAME, Data: []byte(t.Name + "\x00")}, {Type: unix.NFTA_TABLE_FLAGS, Data: []byte{0, 0, 0, 0}}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_NEWTABLE), Flags: netlink.Request | netlink.Acknowledge | netlink.Create, @@ -89,7 +89,7 @@ func (cc *Conn) FlushTable(t *Table) { data := cc.marshalAttr([]netlink.Attribute{ {Type: unix.NFTA_RULE_TABLE, Data: []byte(t.Name + "\x00")}, }) - cc.messages = append(cc.messages, netlink.Message{ + cc.putMessage(netlink.Message{ Header: netlink.Header{ Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_DELRULE), Flags: netlink.Request | netlink.Acknowledge,