Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

List single table or chain by name #258

Merged
merged 1 commit into from
Apr 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,43 @@ func (cc *Conn) ListChains() ([]*Chain, error) {
return cc.ListChainsOfTableFamily(TableFamilyUnspecified)
}

// ListChain returns a single chain configured in the specified table
func (cc *Conn) ListChain(table *Table, chain string) (*Chain, error) {
conn, closer, err := cc.netlinkConn()
if err != nil {
return nil, err
}
defer func() { _ = closer() }()

attrs := []netlink.Attribute{
{Type: unix.NFTA_TABLE_NAME, Data: []byte(table.Name + "\x00")},
{Type: unix.NFTA_CHAIN_NAME, Data: []byte(chain + "\x00")},
}
msg := netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETCHAIN),
Flags: netlink.Request,
},
Data: append(extraHeader(uint8(table.Family), 0), cc.marshalAttr(attrs)...),
}

response, err := conn.Execute(msg)
if err != nil {
return nil, fmt.Errorf("conn.Execute failed: %v", err)
}

if got, want := len(response), 1; got != want {
return nil, fmt.Errorf("expected %d response message for chain, got %d", want, got)
}

ch, err := chainFromMsg(response[0])
if err != nil {
return nil, err
}

return ch, nil
}

// ListChainsOfTableFamily returns currently configured chains for the specified
// family in the kernel. It lists all chains ins all tables if family is
// TableFamilyUnspecified.
Expand Down
154 changes: 154 additions & 0 deletions nftables_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1746,6 +1746,160 @@ func TestListChains(t *testing.T) {
}
}

func TestListChainByName(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
conn.FlushRuleset()
defer conn.FlushRuleset()

table := &nftables.Table{
Name: "chain_test",
Family: nftables.TableFamilyIPv4,
}
tr := conn.AddTable(table)

c := &nftables.Chain{
Name: "filter",
Table: table,
}
conn.AddChain(c)

if err := conn.Flush(); err != nil {
t.Errorf("conn.Flush() failed: %v", err)
}

cr, err := conn.ListChain(tr, c.Name)
if err != nil {
t.Fatalf("conn.ListChain() failed: %v", err)
}

if got, want := cr.Name, c.Name; got != want {
t.Fatalf("got chain %s, want chain %s", got, want)
}

if got, want := cr.Table.Name, table.Name; got != want {
t.Fatalf("got chain table %s, want chain table %s", got, want)
}
}

func TestListChainByNameUsingLasting(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
conn, err := nftables.New(nftables.WithNetNSFd(int(newNS)), nftables.AsLasting())
if err != nil {
t.Fatalf("nftables.New() failed: %v", err)
}
defer nftest.CleanupSystemConn(t, newNS)
conn.FlushRuleset()
defer conn.FlushRuleset()

table := &nftables.Table{
Name: "chain_test_lasting",
Family: nftables.TableFamilyIPv4,
}
tr := conn.AddTable(table)

c := &nftables.Chain{
Name: "filter_lasting",
Table: table,
}
conn.AddChain(c)

if err := conn.Flush(); err != nil {
t.Errorf("conn.Flush() failed: %v", err)
}

cr, err := conn.ListChain(tr, c.Name)
if err != nil {
t.Fatalf("conn.ListChain() failed: %v", err)
}

if got, want := cr.Name, c.Name; got != want {
t.Fatalf("got chain %s, want chain %s", got, want)
}

if got, want := cr.Table.Name, table.Name; got != want {
t.Fatalf("got chain table %s, want chain table %s", got, want)
}
}

func TestListTableByName(t *testing.T) {
conn, newNS := nftest.OpenSystemConn(t, *enableSysTests)
defer nftest.CleanupSystemConn(t, newNS)
conn.FlushRuleset()
defer conn.FlushRuleset()

table1 := &nftables.Table{
Name: "table_test",
Family: nftables.TableFamilyIPv4,
}
conn.AddTable(table1)
table2 := &nftables.Table{
Name: "table_test_inet",
Family: nftables.TableFamilyINet,
}
conn.AddTable(table2)
table3 := &nftables.Table{
Name: table1.Name,
Family: nftables.TableFamilyINet,
}
conn.AddTable(table3)

if err := conn.Flush(); err != nil {
t.Errorf("conn.Flush() failed: %v", err)
}

tr, err := conn.ListTable(table1.Name)
if err != nil {
t.Fatalf("conn.ListTable() failed: %v", err)
}

if got, want := tr.Name, table1.Name; got != want {
t.Fatalf("got table %s, want table %s", got, want)
}

// not specifying table family should return family ipv4
tr, err = conn.ListTable(table3.Name)
if err != nil {
t.Fatalf("conn.ListTable() failed: %v", err)
}
if got, want := tr.Name, table1.Name; got != want {
t.Fatalf("got table %s, want table %s", got, want)
}
if got, want := tr.Family, table1.Family; got != want {
t.Fatalf("got table family %v, want table family %v", got, want)
}

// specifying correct INet family
tr, err = conn.ListTableOfFamily(table3.Name, nftables.TableFamilyINet)
if err != nil {
t.Fatalf("conn.ListTable() failed: %v", err)
}
if got, want := tr.Name, table3.Name; got != want {
t.Fatalf("got table %s, want table %s", got, want)
}
if got, want := tr.Family, table3.Family; got != want {
t.Fatalf("got table family %v, want table family %v", got, want)
}

// not specifying correct family should return err since no table in ipv4
tr, err = conn.ListTable(table2.Name)
if err == nil {
t.Fatalf("conn.ListTable() should have failed")
}

// specifying correct INet family
tr, err = conn.ListTableOfFamily(table2.Name, nftables.TableFamilyINet)
if err != nil {
t.Fatalf("conn.ListTable() failed: %v", err)
}
if got, want := tr.Name, table2.Name; got != want {
t.Fatalf("got table %s, want table %s", got, want)
}
if got, want := tr.Family, table2.Family; got != want {
t.Fatalf("got table family %v, want table family %v", got, want)
}
}

func TestAddChain(t *testing.T) {
tests := []struct {
name string
Expand Down
36 changes: 34 additions & 2 deletions table.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,25 @@ func (cc *Conn) FlushTable(t *Table) {
})
}

// ListTable returns table found for the specified name. Searches for
// the table under IPv4 family. As per nft man page: "When no address
// family is specified, ip is used by default."
func (cc *Conn) ListTable(name string) (*Table, error) {
return cc.ListTableOfFamily(name, TableFamilyIPv4)
}

// ListTableOfFamily returns table found for the specified name and table family
func (cc *Conn) ListTableOfFamily(name string, family TableFamily) (*Table, error) {
t, err := cc.listTablesOfNameAndFamily(name, family)
if err != nil {
return nil, err
}
if got, want := len(t), 1; got != want {
return nil, fmt.Errorf("expected table count %d, got %d", want, got)
}
return t[0], nil
}

// ListTables returns currently configured tables in the kernel
func (cc *Conn) ListTables() ([]*Table, error) {
return cc.ListTablesOfFamily(TableFamilyUnspecified)
Expand All @@ -120,18 +139,31 @@ func (cc *Conn) ListTables() ([]*Table, error) {
// ListTablesOfFamily returns currently configured tables for the specified table family
// in the kernel. It lists all tables if family is TableFamilyUnspecified.
func (cc *Conn) ListTablesOfFamily(family TableFamily) ([]*Table, error) {
return cc.listTablesOfNameAndFamily("", family)
}

func (cc *Conn) listTablesOfNameAndFamily(name string, family TableFamily) ([]*Table, error) {
conn, closer, err := cc.netlinkConn()
if err != nil {
return nil, err
}
defer func() { _ = closer() }()

data := extraHeader(uint8(family), 0)
flags := netlink.Request | netlink.Dump
if name != "" {
data = append(data, cc.marshalAttr([]netlink.Attribute{
{Type: unix.NFTA_TABLE_NAME, Data: []byte(name + "\x00")},
})...)
flags = netlink.Request
}

msg := netlink.Message{
Header: netlink.Header{
Type: netlink.HeaderType((unix.NFNL_SUBSYS_NFTABLES << 8) | unix.NFT_MSG_GETTABLE),
Flags: netlink.Request | netlink.Dump,
Flags: flags,
},
Data: extraHeader(uint8(family), 0),
Data: data,
}

response, err := conn.Execute(msg)
Expand Down
Loading