diff --git a/allowlist_test.go b/allowlist_test.go index 7f5267e..2e8ce3f 100644 --- a/allowlist_test.go +++ b/allowlist_test.go @@ -11,6 +11,29 @@ import ( "github.com/multiformats/go-multiaddr" ) +func ExampleWithAllowlistedMultiaddrs() { + somePeer, err := test.RandPeerID() + if err != nil { + panic("Failed to generate somePeer") + } + + limits := DefaultLimits.AutoScale() + rcmgr, err := NewResourceManager(NewFixedLimiter(limits), WithAllowlistedMultiaddrs([]multiaddr.Multiaddr{ + // Any peer connecting from this IP address + multiaddr.StringCast("/ip4/1.2.3.4"), + // Only the specified peer from this address + multiaddr.StringCast("/ip4/2.2.3.4/p2p/" + peer.Encode(somePeer)), + // Only peers from this 1.2.3.0/24 IP address range + multiaddr.StringCast("/ip4/1.2.3.0/ipcidr/24"), + })) + if err != nil { + panic("Failed to start resource manager") + } + + // Use rcmgr as before + _ = rcmgr +} + func TestAllowedSimple(t *testing.T) { allowlist := newAllowlist() ma := multiaddr.StringCast("/ip4/1.2.3.4/tcp/1234") diff --git a/rcmgr.go b/rcmgr.go index 61f3a5e..b2cd18c 100644 --- a/rcmgr.go +++ b/rcmgr.go @@ -167,6 +167,18 @@ func (r *resourceManager) GetAllowlist() *Allowlist { return r.allowlist } +// GetAllowlist tries to get the allowlist from the given resourcemanager +// interface by checking to see if its concrete type is a resourceManager. +// Returns nil if it fails to get the allowlist. +func GetAllowlist(rcmgr network.ResourceManager) *Allowlist { + r, ok := rcmgr.(*resourceManager) + if !ok { + return nil + } + + return r.allowlist +} + func (r *resourceManager) ViewSystem(f func(network.ResourceScope) error) error { return f(r.system) } diff --git a/rcmgr_test.go b/rcmgr_test.go index 47e38da..6cf7e5f 100644 --- a/rcmgr_test.go +++ b/rcmgr_test.go @@ -1005,6 +1005,11 @@ func TestResourceManagerWithAllowlist(t *testing.T) { t.Fatal(err) } + ableToGetAllowlist := GetAllowlist(rcmgr) + if ableToGetAllowlist == nil { + t.Fatal("Expected to be able to get the allowlist") + } + // A connection comes in from a non-allowlisted ip address _, err = rcmgr.OpenConnection(network.DirInbound, true, multiaddr.StringCast("/ip4/1.2.3.5")) if err == nil {