diff --git a/protocols/bgp/server/bgp_api_test.go b/protocols/bgp/server/bgp_api_test.go index b7c282493..4113cc6ab 100644 --- a/protocols/bgp/server/bgp_api_test.go +++ b/protocols/bgp/server/bgp_api_test.go @@ -15,6 +15,7 @@ import ( "github.com/bio-routing/bio-rd/routingtable/adjRIBIn" "github.com/bio-routing/bio-rd/routingtable/adjRIBOut" "github.com/bio-routing/bio-rd/routingtable/filter" + "github.com/bio-routing/bio-rd/routingtable/vrf" "github.com/stretchr/testify/assert" "google.golang.org/grpc" "google.golang.org/grpc/test/bufconn" @@ -30,6 +31,8 @@ func TestDumpRIBInOut(t *testing.T) { AddPathTX: true, } + vrf, _ := vrf.New("vrf0", 0) + tests := []struct { name string apisrv *BGPAPIServer @@ -66,7 +69,7 @@ func TestDumpRIBInOut(t *testing.T) { fsms: []*FSM{ 0: { ipv4Unicast: &fsmAddressFamily{ - adjRIBIn: adjRIBIn.New(filter.NewAcceptAllFilterChain(), nil, sessionAttrs), + adjRIBIn: adjRIBIn.New(filter.NewAcceptAllFilterChain(), vrf, sessionAttrs), adjRIBOut: adjRIBOut.New(nil, routingtable.SessionAttrs{Type: route.BGPPathType}, filter.NewAcceptAllFilterChain()), }, }, @@ -96,7 +99,7 @@ func TestDumpRIBInOut(t *testing.T) { fsms: []*FSM{ 0: { ipv4Unicast: &fsmAddressFamily{ - adjRIBIn: adjRIBIn.New(filter.NewAcceptAllFilterChain(), nil, sessionAttrs), + adjRIBIn: adjRIBIn.New(filter.NewAcceptAllFilterChain(), vrf, sessionAttrs), adjRIBOut: adjRIBOut.New(nil, routingtable.SessionAttrs{Type: route.BGPPathType, RouteServerClient: true, PeerIP: bnet.IPv4(0).Ptr()}, filter.NewAcceptAllFilterChain()), }, }, @@ -156,7 +159,7 @@ func TestDumpRIBInOut(t *testing.T) { fsms: []*FSM{ 0: { ipv4Unicast: &fsmAddressFamily{ - adjRIBIn: adjRIBIn.New(filter.NewAcceptAllFilterChain(), routingtable.NewContributingASNs(), sessionAttrs), + adjRIBIn: adjRIBIn.New(filter.NewAcceptAllFilterChain(), vrf, sessionAttrs), adjRIBOut: adjRIBOut.New(nil, routingtable.SessionAttrs{Type: route.BGPPathType, RouteServerClient: true, PeerIP: bnet.IPv4(123).Ptr()}, filter.NewAcceptAllFilterChain()), }, }, diff --git a/protocols/bgp/server/fsm_address_family.go b/protocols/bgp/server/fsm_address_family.go index d84ab9809..a516931a6 100644 --- a/protocols/bgp/server/fsm_address_family.go +++ b/protocols/bgp/server/fsm_address_family.go @@ -13,6 +13,7 @@ import ( "github.com/bio-routing/bio-rd/routingtable/adjRIBOut" "github.com/bio-routing/bio-rd/routingtable/filter" "github.com/bio-routing/bio-rd/routingtable/locRIB" + "github.com/bio-routing/bio-rd/routingtable/vrf" ) // fsmAddressFamily holds RIBs and the UpdateSender of an peer for an AFI/SAFI combination @@ -24,6 +25,7 @@ type fsmAddressFamily struct { adjRIBIn routingtable.AdjRIBIn adjRIBOut routingtable.AdjRIBOut rib *locRIB.LocRIB + vrf *vrf.VRF importFilterChain filter.Chain exportFilterChain filter.Chain @@ -44,6 +46,7 @@ func newFSMAddressFamily(afi uint16, safi uint8, family *peerAddressFamily, fsm afi: afi, safi: safi, fsm: fsm, + vrf: fsm.peer.vrf, rib: family.rib, importFilterChain: family.importFilterChain, exportFilterChain: family.exportFilterChain, @@ -80,21 +83,20 @@ func (f *fsmAddressFamily) dumpRIBIn() []*route.Route { } type adjRIBInFactoryI interface { - New(exportFilterChain filter.Chain, contributingASNs *routingtable.ContributingASNs, sessionAttrs routingtable.SessionAttrs) routingtable.AdjRIBIn + New(exportFilterChain filter.Chain, vrf *vrf.VRF, sessionAttrs routingtable.SessionAttrs) routingtable.AdjRIBIn } type adjRIBInFactory struct{} -func (a adjRIBInFactory) New(exportFilterChain filter.Chain, contributingASNs *routingtable.ContributingASNs, sessionAttrs routingtable.SessionAttrs) routingtable.AdjRIBIn { - return adjRIBIn.New(exportFilterChain, contributingASNs, sessionAttrs) +func (a adjRIBInFactory) New(exportFilterChain filter.Chain, vrf *vrf.VRF, sessionAttrs routingtable.SessionAttrs) routingtable.AdjRIBIn { + return adjRIBIn.New(exportFilterChain, vrf, sessionAttrs) } func (f *fsmAddressFamily) init() { - contributingASNs := f.rib.GetContributingASNs() sessionAttrs := f.getSessionAttrs() - f.adjRIBIn = f.fsm.peer.adjRIBInFactory.New(f.importFilterChain, contributingASNs, sessionAttrs) - contributingASNs.Add(f.fsm.peer.localASN) + f.adjRIBIn = f.fsm.peer.adjRIBInFactory.New(f.importFilterChain, f.vrf, sessionAttrs) + f.vrf.AddContributingASN(f.fsm.peer.localASN) f.adjRIBIn.Register(f.rib) @@ -138,7 +140,7 @@ func (f *fsmAddressFamily) getSessionAttrs() routingtable.SessionAttrs { } func (f *fsmAddressFamily) bmpInit() { - f.adjRIBIn = f.fsm.peer.adjRIBInFactory.New(filter.NewAcceptAllFilterChain(), &routingtable.ContributingASNs{}, f.getSessionAttrs()) + f.adjRIBIn = f.fsm.peer.adjRIBInFactory.New(filter.NewAcceptAllFilterChain(), f.fsm.peer.vrf, f.getSessionAttrs()) if f.rib != nil { f.adjRIBIn.Register(f.rib) @@ -148,7 +150,7 @@ func (f *fsmAddressFamily) bmpInit() { } func (f *fsmAddressFamily) bmpDispose() { - f.rib.GetContributingASNs().Remove(f.fsm.peer.localASN) + f.vrf.RemoveContributingASN(f.fsm.peer.localASN) f.adjRIBIn.Flush() @@ -162,7 +164,7 @@ func (f *fsmAddressFamily) dispose() { return } - f.rib.GetContributingASNs().Remove(f.fsm.peer.localASN) + f.vrf.RemoveContributingASN(f.fsm.peer.localASN) f.adjRIBIn.Unregister(f.rib) f.rib.Unregister(f.adjRIBOut) f.adjRIBOut.Unregister(f.updateSender) diff --git a/protocols/bgp/server/fsm_address_family_test.go b/protocols/bgp/server/fsm_address_family_test.go index e40a927ef..0d01a9b7f 100644 --- a/protocols/bgp/server/fsm_address_family_test.go +++ b/protocols/bgp/server/fsm_address_family_test.go @@ -10,6 +10,7 @@ import ( "github.com/bio-routing/bio-rd/routingtable" "github.com/bio-routing/bio-rd/routingtable/filter" "github.com/bio-routing/bio-rd/routingtable/locRIB" + "github.com/bio-routing/bio-rd/routingtable/vrf" "github.com/stretchr/testify/assert" biotesting "github.com/bio-routing/bio-rd/testing" @@ -20,6 +21,7 @@ func TestFSMAFIInitDispose(t *testing.T) { afi: packet.AFIIPv4, safi: packet.SAFIUnicast, rib: locRIB.New("inet.0"), + vrf: vrf.NewUntrackedVRF("vrf0", 0), importFilterChain: filter.NewAcceptAllFilterChain(), exportFilterChain: filter.NewAcceptAllFilterChain(), fsm: &FSM{ @@ -41,8 +43,8 @@ func TestFSMAFIInitDispose(t *testing.T) { f.init() assert.NotEqual(t, nil, f.adjRIBIn) - assert.Equal(t, true, f.rib.GetContributingASNs().IsContributingASN(15169)) - assert.NotEqual(t, true, f.rib.GetContributingASNs().IsContributingASN(15170)) + assert.Equal(t, true, f.vrf.IsContributingASN(15169)) + assert.NotEqual(t, true, f.vrf.IsContributingASN(15170)) assert.NotEqual(t, nil, f.adjRIBOut) assert.NotEqual(t, nil, f.updateSender) @@ -57,7 +59,7 @@ func TestFSMAFIInitDispose(t *testing.T) { f.dispose() f.updateSender.wg.Wait() - assert.Equal(t, false, f.rib.GetContributingASNs().IsContributingASN(15169)) + assert.Equal(t, false, f.vrf.IsContributingASN(15169)) assert.Equal(t, uint64(0), f.rib.ClientCount()) assert.Equal(t, nil, f.adjRIBOut) assert.Equal(t, false, f.initialized) diff --git a/protocols/bgp/server/fsm_test.go b/protocols/bgp/server/fsm_test.go index 324f16ef4..379affe2b 100644 --- a/protocols/bgp/server/fsm_test.go +++ b/protocols/bgp/server/fsm_test.go @@ -9,6 +9,7 @@ import ( "github.com/bio-routing/bio-rd/protocols/bgp/packet" "github.com/bio-routing/bio-rd/routingtable/filter" "github.com/bio-routing/bio-rd/routingtable/locRIB" + "github.com/bio-routing/bio-rd/routingtable/vrf" "github.com/stretchr/testify/assert" ) @@ -23,6 +24,7 @@ func TestFSM255UpdatesIPv4(t *testing.T) { exportFilterChain: filter.NewAcceptAllFilterChain(), }, adjRIBInFactory: adjRIBInFactory{}, + vrf: vrf.NewUntrackedVRF("vrf0", 0), }) fsmA.holdTime = time.Second * 180 @@ -142,6 +144,7 @@ func TestFSM255UpdatesIPv6(t *testing.T) { exportFilterChain: filter.NewAcceptAllFilterChain(), }, adjRIBInFactory: adjRIBInFactory{}, + vrf: vrf.NewUntrackedVRF("vrf0", 0), }) fsmA.ipv6Unicast.multiProtocol = true diff --git a/protocols/bgp/server/metrics_service_test.go b/protocols/bgp/server/metrics_service_test.go index 5219515b0..c7dcf1a3c 100644 --- a/protocols/bgp/server/metrics_service_test.go +++ b/protocols/bgp/server/metrics_service_test.go @@ -4,17 +4,17 @@ import ( "testing" "time" - "github.com/bio-routing/bio-rd/protocols/bgp/packet" "github.com/bio-routing/bio-rd/routingtable" "github.com/bio-routing/bio-rd/routingtable/vrf" "github.com/stretchr/testify/assert" bnet "github.com/bio-routing/bio-rd/net" "github.com/bio-routing/bio-rd/protocols/bgp/metrics" + "github.com/bio-routing/bio-rd/protocols/bgp/packet" ) func TestMetrics(t *testing.T) { - vrf, _ := vrf.New("inet.0", 0) + vrf := vrf.NewUntrackedVRF("inet.0", 0) establishedTime := time.Now() tests := []struct { diff --git a/routingtable/adjRIBIn/adj_rib_in.go b/routingtable/adjRIBIn/adj_rib_in.go index 9c8eeedc1..caeb16b9b 100644 --- a/routingtable/adjRIBIn/adj_rib_in.go +++ b/routingtable/adjRIBIn/adj_rib_in.go @@ -8,6 +8,7 @@ import ( "github.com/bio-routing/bio-rd/route" "github.com/bio-routing/bio-rd/routingtable" "github.com/bio-routing/bio-rd/routingtable/filter" + "github.com/bio-routing/bio-rd/routingtable/vrf" "github.com/bio-routing/bio-rd/util/log" ) @@ -17,16 +18,16 @@ type AdjRIBIn struct { rt *routingtable.RoutingTable mu sync.RWMutex exportFilterChain filter.Chain - contributingASNs *routingtable.ContributingASNs + vrf *vrf.VRF sessionAttrs routingtable.SessionAttrs } // New creates a new Adjacency RIB In -func New(exportFilterChain filter.Chain, contributingASNs *routingtable.ContributingASNs, sessionAttrs routingtable.SessionAttrs) *AdjRIBIn { +func New(exportFilterChain filter.Chain, vrf *vrf.VRF, sessionAttrs routingtable.SessionAttrs) *AdjRIBIn { a := &AdjRIBIn{ rt: routingtable.NewRoutingTable(), exportFilterChain: exportFilterChain, - contributingASNs: contributingASNs, + vrf: vrf, sessionAttrs: sessionAttrs, } a.clientManager = routingtable.NewClientManager(a) @@ -319,7 +320,7 @@ func (a *AdjRIBIn) ourASNsInPath(p *route.Path) bool { for _, pathSegment := range *p.BGPPath.ASPath { for _, asn := range pathSegment.ASNs { - if a.contributingASNs.IsContributingASN(asn) { + if a.vrf.IsContributingASN(asn) { return true } } diff --git a/routingtable/adjRIBIn/adj_rib_in_test.go b/routingtable/adjRIBIn/adj_rib_in_test.go index 432e70363..2346dd0a2 100644 --- a/routingtable/adjRIBIn/adj_rib_in_test.go +++ b/routingtable/adjRIBIn/adj_rib_in_test.go @@ -9,6 +9,7 @@ import ( "github.com/bio-routing/bio-rd/route" "github.com/bio-routing/bio-rd/routingtable" "github.com/bio-routing/bio-rd/routingtable/filter" + "github.com/bio-routing/bio-rd/routingtable/vrf" "github.com/stretchr/testify/assert" ) @@ -265,7 +266,7 @@ func TestAddPath(t *testing.T) { ClusterID: clusterID, AddPathRX: test.addPath, } - adjRIBIn := New(filter.NewAcceptAllFilterChain(), routingtable.NewContributingASNs(), sessionAttrs) + adjRIBIn := New(filter.NewAcceptAllFilterChain(), vrf.NewUntrackedVRF("vrf0", 0), sessionAttrs) mc := routingtable.NewRTMockClient() adjRIBIn.clientManager.RegisterWithOptions(mc, routingtable.ClientOptions{BestOnly: true}) @@ -495,7 +496,7 @@ func TestRemovePath(t *testing.T) { ClusterID: 2, AddPathRX: test.addPath, } - adjRIBIn := New(filter.NewAcceptAllFilterChain(), routingtable.NewContributingASNs(), sessionAttrs) + adjRIBIn := New(filter.NewAcceptAllFilterChain(), vrf.NewUntrackedVRF("vrf0", 0), sessionAttrs) for _, route := range test.routes { adjRIBIn.AddPath(route.Prefix().Ptr(), route.Paths()[0]) } @@ -523,7 +524,7 @@ func TestRemovePath(t *testing.T) { } func TestUnregister(t *testing.T) { - adjRIBIn := New(filter.NewAcceptAllFilterChain(), routingtable.NewContributingASNs(), routingtable.SessionAttrs{}) + adjRIBIn := New(filter.NewAcceptAllFilterChain(), vrf.NewUntrackedVRF("vrf0", 0), routingtable.SessionAttrs{}) mc := routingtable.NewRTMockClient() adjRIBIn.Register(mc) @@ -839,7 +840,7 @@ func TestPeerRoleOTC(t *testing.T) { } for _, test := range tests { - adjRIBIn := New(filter.NewAcceptAllFilterChain(), routingtable.NewContributingASNs(), test.sessionAttrs) + adjRIBIn := New(filter.NewAcceptAllFilterChain(), vrf.NewUntrackedVRF("vrf0", 0), test.sessionAttrs) mc := routingtable.NewRTMockClient() adjRIBIn.clientManager.RegisterWithOptions(mc, routingtable.ClientOptions{BestOnly: true}) diff --git a/routingtable/contributing_asn_list.go b/routingtable/contributing_asn_list.go deleted file mode 100644 index d0c7f4a26..000000000 --- a/routingtable/contributing_asn_list.go +++ /dev/null @@ -1,88 +0,0 @@ -package routingtable - -import ( - "fmt" - "math" - "sync" -) - -type contributingASN struct { - asn uint32 - count uint32 -} - -// ContributingASNs contains a list of contributing ASN to a LocRIB to check ASPaths for possible routing loops. -type ContributingASNs struct { - contributingASNs []*contributingASN - contributingASNsLock sync.RWMutex -} - -// NewContributingASNs creates a list of contributing ASNs to a LocRIB for routing loop prevention. -func NewContributingASNs() *ContributingASNs { - c := &ContributingASNs{ - contributingASNs: []*contributingASN{}, - } - - return c -} - -// Add a new ASN to the list of contributing ASNs or add the ref count of an existing one. -func (c *ContributingASNs) Add(asn uint32) { - c.contributingASNsLock.Lock() - defer c.contributingASNsLock.Unlock() - - for _, cASN := range c.contributingASNs { - if cASN.asn == asn { - cASN.count++ - - if cASN.count == math.MaxUint32 { - panic(fmt.Sprintf("Contributing ASNs counter overflow triggered for AS %d. Dying of shame.", asn)) - } - - return - } - } - - c.contributingASNs = append(c.contributingASNs, &contributingASN{ - asn: asn, - count: 1, - }) -} - -// Remove a ASN to the list of contributing ASNs or decrement the ref count of an existing one. -func (c *ContributingASNs) Remove(asn uint32) { - c.contributingASNsLock.Lock() - defer c.contributingASNsLock.Unlock() - - asnList := c.contributingASNs - - for i, cASN := range asnList { - if cASN.asn != asn { - continue - } - - cASN.count-- - - if cASN.count == 0 { - copy(asnList[i:], asnList[i+1:]) - asnList = asnList[:] - c.contributingASNs = asnList[:len(asnList)-1] - } - - return - } -} - -// IsContributingASN checks if a given ASN is part of the contributing ASNs -func (c *ContributingASNs) IsContributingASN(asn uint32) bool { - c.contributingASNsLock.RLock() - defer c.contributingASNsLock.RUnlock() - - for _, cASN := range c.contributingASNs { - if asn == cASN.asn { - return true - } - } - - return false -} diff --git a/routingtable/contributing_asn_list_test.go b/routingtable/contributing_asn_list_test.go deleted file mode 100644 index 6387b18bc..000000000 --- a/routingtable/contributing_asn_list_test.go +++ /dev/null @@ -1,90 +0,0 @@ -package routingtable - -import ( - "fmt" - "testing" -) - -func TestFancy(t *testing.T) { - c := NewContributingASNs() - - tests := []struct { - runCmd func() - expect func() bool - msg string - }{ - // Empty list - { - runCmd: func() {}, - expect: func() bool { return !c.IsContributingASN(41981) }, - msg: "AS41981 shouldn't be contributing yet.", - }, - - // Add and remove one ASN - { - runCmd: func() { c.Add(41981) }, - expect: func() bool { return c.IsContributingASN(41981) }, - msg: "AS41981 should be contributing.", - }, - { - runCmd: func() { c.Remove(41981) }, - expect: func() bool { return !c.IsContributingASN(41981) }, - msg: "AS41981 shouldn't be contributing no more.", - }, - - // Two ASNs present - { - runCmd: func() { c.Add(41981) }, - expect: func() bool { return c.IsContributingASN(41981) }, - msg: "AS41981 should be contributing.", - }, - { - runCmd: func() { c.Add(201701) }, - expect: func() bool { return c.IsContributingASN(41981) }, - msg: "AS201701 should be contributing.", - }, - - // Add AS41981 2nd time - { - runCmd: func() { c.Add(41981) }, - expect: func() bool { return c.IsContributingASN(41981) }, - msg: "AS41981 should be still contributing.", - }, - { - runCmd: func() {}, - expect: func() bool { return c.contributingASNs[0].asn == 41981 }, - msg: "AS41981 is first ASN in list.", - }, - { - runCmd: func() { fmt.Printf("%+v", c.contributingASNs) }, - expect: func() bool { return c.contributingASNs[0].count == 2 }, - msg: "AS41981 should be present twice.", - }, - - // Remove 2nd AS41981 - { - runCmd: func() { c.Remove(41981) }, - expect: func() bool { return c.IsContributingASN(41981) }, - msg: "AS41981 should still be contributing.", - }, - { - runCmd: func() {}, - expect: func() bool { return c.contributingASNs[0].count == 1 }, - msg: "S41981 should be present once.", - }, - - // Remove AS201701 - { - runCmd: func() { c.Remove(201701) }, - expect: func() bool { return !c.IsContributingASN(201701) }, - msg: "AS201701 shouldn't be contributing no more.", - }, - } - - for i, test := range tests { - test.runCmd() - if !test.expect() { - t.Errorf("Test %d failed: %v", i, test.msg) - } - } -} diff --git a/routingtable/locRIB/loc_rib.go b/routingtable/locRIB/loc_rib.go index a9fc4c858..d77e19db5 100644 --- a/routingtable/locRIB/loc_rib.go +++ b/routingtable/locRIB/loc_rib.go @@ -14,12 +14,11 @@ import ( // LocRIB represents a routing information base type LocRIB struct { - name string - clientManager *routingtable.ClientManager - rt *routingtable.RoutingTable - mu sync.RWMutex - contributingASNs *routingtable.ContributingASNs - countTarget *countTarget + name string + clientManager *routingtable.ClientManager + rt *routingtable.RoutingTable + mu sync.RWMutex + countTarget *countTarget } type countTarget struct { @@ -30,9 +29,8 @@ type countTarget struct { // New creates a new routing information base func New(name string) *LocRIB { a := &LocRIB{ - name: name, - rt: routingtable.NewRoutingTable(), - contributingASNs: routingtable.NewContributingASNs(), + name: name, + rt: routingtable.NewRoutingTable(), } a.clientManager = routingtable.NewClientManager(a) @@ -51,11 +49,6 @@ func (a *LocRIB) ClientCount() uint64 { return a.clientManager.ClientCount() } -// GetContributingASNs returns a pointer to the list of contributing ASNs -func (a *LocRIB) GetContributingASNs() *routingtable.ContributingASNs { - return a.contributingASNs -} - // Count routes from the LocRIB func (a *LocRIB) Count() uint64 { return uint64(a.rt.GetRouteCount()) diff --git a/routingtable/vrf/vrf.go b/routingtable/vrf/vrf.go index f5b5f735f..0e9731d54 100644 --- a/routingtable/vrf/vrf.go +++ b/routingtable/vrf/vrf.go @@ -7,6 +7,7 @@ import ( "sync" "github.com/bio-routing/bio-rd/routingtable/locRIB" + "github.com/bio-routing/bio-rd/util/refcounter" ) const ( @@ -27,11 +28,12 @@ type VRF struct { ribs map[addressFamily]*locRIB.LocRIB mu sync.Mutex ribNames map[string]*locRIB.LocRIB + contributingASNs *refcounter.RefcounterUint32 } // New creates a new VRF. The VRF is registered automatically to the global VRF registry. func New(name string, rd uint64) (*VRF, error) { - v := newUntrackedVRF(name, rd) + v := NewUntrackedVRF(name, rd) v.CreateIPv4UnicastLocRIB("inet.0") v.CreateIPv6UnicastLocRIB("inet6.0") @@ -43,12 +45,13 @@ func New(name string, rd uint64) (*VRF, error) { return v, nil } -func newUntrackedVRF(name string, rd uint64) *VRF { +func NewUntrackedVRF(name string, rd uint64) *VRF { return &VRF{ name: name, routeDistinguisher: rd, ribs: make(map[addressFamily]*locRIB.LocRIB), ribNames: make(map[string]*locRIB.LocRIB), + contributingASNs: refcounter.NewRefCounterUint32(), } } @@ -119,6 +122,19 @@ func (v *VRF) RIBByName(name string) (rib *locRIB.LocRIB, found bool) { return rib, found } +func (v *VRF) AddContributingASN(asn uint32) { + v.contributingASNs.Add(asn) +} + +func (v *VRF) RemoveContributingASN(asn uint32) { + v.contributingASNs.Remove(asn) +} + +// IsContributingASN returns wether the given ASN is used by this BGP speaker somewhere within this VRF +func (v *VRF) IsContributingASN(asn uint32) bool { + return v.contributingASNs.IsPresent(asn) +} + func (v *VRF) nameForRIB(rib *locRIB.LocRIB) string { for name, r := range v.ribNames { if r == rib { diff --git a/routingtable/vrf/vrf_registry.go b/routingtable/vrf/vrf_registry.go index ed0f409a8..e30a3f28a 100644 --- a/routingtable/vrf/vrf_registry.go +++ b/routingtable/vrf/vrf_registry.go @@ -31,7 +31,7 @@ func (r *VRFRegistry) CreateVRFIfNotExists(name string, rd uint64) *VRF { return r.vrfs[rd] } - r.vrfs[rd] = newUntrackedVRF(name, rd) + r.vrfs[rd] = NewUntrackedVRF(name, rd) r.vrfs[rd].CreateIPv4UnicastLocRIB("inet.0") r.vrfs[rd].CreateIPv6UnicastLocRIB("inet6.0") return r.vrfs[rd] diff --git a/routingtable/vrf/vrf_test.go b/routingtable/vrf/vrf_test.go index 32422f7a4..73d12b772 100644 --- a/routingtable/vrf/vrf_test.go +++ b/routingtable/vrf/vrf_test.go @@ -15,7 +15,7 @@ func TestNewWithDuplicate(t *testing.T) { } func TestIPv4UnicastRIBWith(t *testing.T) { - v := newUntrackedVRF("master", 0) + v := NewUntrackedVRF("master", 0) rib, err := v.CreateIPv4UnicastLocRIB("inet.0") assert.Equal(t, rib, v.IPv4UnicastRIB()) @@ -23,7 +23,7 @@ func TestIPv4UnicastRIBWith(t *testing.T) { } func TestIPv6UnicastRIB(t *testing.T) { - v := newUntrackedVRF("master", 0) + v := NewUntrackedVRF("master", 0) rib, err := v.CreateIPv6UnicastLocRIB("inet6.0") assert.Equal(t, rib, v.IPv6UnicastRIB()) @@ -31,7 +31,7 @@ func TestIPv6UnicastRIB(t *testing.T) { } func TestCreateLocRIBTwice(t *testing.T) { - v := newUntrackedVRF("master", 0) + v := NewUntrackedVRF("master", 0) _, err := v.CreateIPv6UnicastLocRIB("inet6.0") assert.Nil(t, err, "error must be nil on first invokation") @@ -40,7 +40,7 @@ func TestCreateLocRIBTwice(t *testing.T) { } func TestRIBByName(t *testing.T) { - v := newUntrackedVRF("master", 0) + v := NewUntrackedVRF("master", 0) rib, _ := v.CreateIPv6UnicastLocRIB("inet6.0") assert.NotNil(t, rib, "rib must not be nil after creation") @@ -50,7 +50,7 @@ func TestRIBByName(t *testing.T) { } func TestName(t *testing.T) { - v := newUntrackedVRF("foo", 0) + v := NewUntrackedVRF("foo", 0) assert.Equal(t, "foo", v.Name()) } diff --git a/util/refcounter/refcounter_uint32.go b/util/refcounter/refcounter_uint32.go new file mode 100644 index 000000000..92b6f668a --- /dev/null +++ b/util/refcounter/refcounter_uint32.go @@ -0,0 +1,88 @@ +package refcounter + +import ( + "fmt" + "math" + "sync" +) + +type item struct { + value uint32 + count uint32 +} + +// RefcounterUint32 contains a list of items to keep refcounts on +type RefcounterUint32 struct { + items []*item + itemsMu sync.RWMutex +} + +// NewRefCounterUint32 creates a list of items to keep refcounts on +func NewRefCounterUint32() *RefcounterUint32 { + c := &RefcounterUint32{ + items: []*item{}, + } + + return c +} + +// Add adds new item to the list of items or add the ref count of an existing one. +func (r *RefcounterUint32) Add(value uint32) { + r.itemsMu.Lock() + defer r.itemsMu.Unlock() + + for _, iterItem := range r.items { + if iterItem.value == value { + iterItem.count++ + + if iterItem.count == math.MaxUint32 { + panic(fmt.Sprintf("Counter overflow triggered for item %d. Dying of shame.", value)) + } + + return + } + } + + r.items = append(r.items, &item{ + value: value, + count: 1, + }) +} + +// Remove a value from the list of items or decrement the ref count of an existing one. +func (r *RefcounterUint32) Remove(value uint32) { + r.itemsMu.Lock() + defer r.itemsMu.Unlock() + + itemList := r.items + + for i, iterItem := range itemList { + if iterItem.value != value { + continue + } + + iterItem.count-- + + if iterItem.count == 0 { + copy(itemList[i:], itemList[i+1:]) + itemList = itemList[:] + r.items = itemList[:len(itemList)-1] + } + + return + } +} + +// IsPresent checks if a given value is part of the known items +func (r *RefcounterUint32) IsPresent(value uint32) bool { + r.itemsMu.RLock() + defer r.itemsMu.RUnlock() + + for _, iterItem := range r.items { + if value == iterItem.value { + return true + } + } + + return false +} diff --git a/util/refcounter/refcounter_uint32_test.go b/util/refcounter/refcounter_uint32_test.go new file mode 100644 index 000000000..2bdeb0fc1 --- /dev/null +++ b/util/refcounter/refcounter_uint32_test.go @@ -0,0 +1,90 @@ +package refcounter + +import ( + "fmt" + "testing" +) + +func TestFancy(t *testing.T) { + r := NewRefCounterUint32() + + tests := []struct { + runCmd func() + expect func() bool + msg string + }{ + // Empty list + { + runCmd: func() {}, + expect: func() bool { return !r.IsPresent(41981) }, + msg: "41981 shouldn't be present yet.", + }, + + // Add and remove one item + { + runCmd: func() { r.Add(41981) }, + expect: func() bool { return r.IsPresent(41981) }, + msg: "41981 should be contributing.", + }, + { + runCmd: func() { r.Remove(41981) }, + expect: func() bool { return !r.IsPresent(41981) }, + msg: "41981 shouldn't be contributing no more.", + }, + + // Two items present + { + runCmd: func() { r.Add(41981) }, + expect: func() bool { return r.IsPresent(41981) }, + msg: "41981 should be contributing.", + }, + { + runCmd: func() { r.Add(201701) }, + expect: func() bool { return r.IsPresent(41981) }, + msg: "201701 should be contributing.", + }, + + // Add 41981 2nd time + { + runCmd: func() { r.Add(41981) }, + expect: func() bool { return r.IsPresent(41981) }, + msg: "41981 should be still contributing.", + }, + { + runCmd: func() {}, + expect: func() bool { return r.items[0].value == 41981 }, + msg: "41981 is first item in list.", + }, + { + runCmd: func() { fmt.Printf("%+v", r.items) }, + expect: func() bool { return r.items[0].count == 2 }, + msg: "41981 should be present twice.", + }, + + // Remove 2nd 41981 + { + runCmd: func() { r.Remove(41981) }, + expect: func() bool { return r.IsPresent(41981) }, + msg: "41981 should still be contributing.", + }, + { + runCmd: func() {}, + expect: func() bool { return r.items[0].count == 1 }, + msg: "41981 should be present once.", + }, + + // Remove 201701 + { + runCmd: func() { r.Remove(201701) }, + expect: func() bool { return !r.IsPresent(201701) }, + msg: "201701 shouldn't be contributing no more.", + }, + } + + for i, test := range tests { + test.runCmd() + if !test.expect() { + t.Errorf("Test %d failed: %v", i, test.msg) + } + } +}