From 44ee898e911835bec09fef2a9d8d3175f7b85a03 Mon Sep 17 00:00:00 2001 From: Richard Ramos Date: Thu, 29 Jul 2021 12:59:09 -0400 Subject: [PATCH 1/2] add support for custom protocol matching function --- pubsub.go | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/pubsub.go b/pubsub.go index 3b74fa60..ae6aa96e 100644 --- a/pubsub.go +++ b/pubsub.go @@ -36,6 +36,8 @@ var ( var log = logging.Logger("pubsub") +type MatchingFunction func(string) func(string) bool + // PubSub is the implementation of the pubsub system. type PubSub struct { // atomic counter for seqnos @@ -157,6 +159,9 @@ type PubSub struct { // filter for tracking subscriptions in topics of interest; if nil, then we track all subscriptions subFilter SubscriptionFilter + // protoMatchFunc is a matching function for protocol selection. + protoMatchFunc *MatchingFunction + ctx context.Context } @@ -235,6 +240,7 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option peerOutboundQueueSize: 32, signID: h.ID(), signKey: nil, + protoMatchFunc: nil, signPolicy: StrictSign, incoming: make(chan *RPC, 32), newPeers: make(chan struct{}, 1), @@ -292,7 +298,11 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option rt.Attach(ps) for _, id := range rt.Protocols() { - h.SetStreamHandler(id, ps.handleNewStream) + if ps.protoMatchFunc != nil { + h.SetStreamHandlerMatch(id, (*ps.protoMatchFunc)(string(id)), ps.handleNewStream) + } else { + h.SetStreamHandler(id, ps.handleNewStream) + } } h.Network().Notify((*PubSubNotif)(ps)) @@ -475,6 +485,15 @@ func WithMaxMessageSize(maxMessageSize int) Option { } } +// WithProtocolMatchFunction sets a custom matching function for protocol +// selection to be used by the protocol handler on the Host's Mux +func WithProtocolMatchFunction(m MatchingFunction) Option { + return func(ps *PubSub) error { + ps.protoMatchFunc = &m + return nil + } +} + // processLoop handles all inputs arriving on the channels func (p *PubSub) processLoop(ctx context.Context) { defer func() { From 1549a175928d5a64c391416908f80fdafae92823 Mon Sep 17 00:00:00 2001 From: Richard Ramos Date: Thu, 29 Jul 2021 17:54:04 -0400 Subject: [PATCH 2/2] fix: code review --- gossipsub_matchfn_test.go | 84 +++++++++++++++++++++++++++++++++++++++ pubsub.go | 17 ++++---- 2 files changed, 93 insertions(+), 8 deletions(-) create mode 100644 gossipsub_matchfn_test.go diff --git a/gossipsub_matchfn_test.go b/gossipsub_matchfn_test.go new file mode 100644 index 00000000..016c979c --- /dev/null +++ b/gossipsub_matchfn_test.go @@ -0,0 +1,84 @@ +package pubsub + +import ( + "context" + "strings" + "testing" + "time" + + "github.com/libp2p/go-libp2p-core/protocol" +) + +func TestGossipSubMatchingFn(t *testing.T) { + customsubA100 := protocol.ID("/customsub_a/1.0.0") + customsubA101Beta := protocol.ID("/customsub_a/1.0.1-beta") + customsubB100 := protocol.ID("/customsub_b/1.0.0") + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + h := getNetHosts(t, ctx, 4) + psubs := []*PubSub{ + getGossipsub(ctx, h[0], WithProtocolMatchFn(protocolNameMatch), WithGossipSubProtocols([]protocol.ID{customsubA100, GossipSubID_v11}, GossipSubDefaultFeatures)), + getGossipsub(ctx, h[1], WithProtocolMatchFn(protocolNameMatch), WithGossipSubProtocols([]protocol.ID{customsubA101Beta}, GossipSubDefaultFeatures)), + getGossipsub(ctx, h[2], WithProtocolMatchFn(protocolNameMatch), WithGossipSubProtocols([]protocol.ID{GossipSubID_v11}, GossipSubDefaultFeatures)), + getGossipsub(ctx, h[3], WithProtocolMatchFn(protocolNameMatch), WithGossipSubProtocols([]protocol.ID{customsubB100}, GossipSubDefaultFeatures)), + } + + connect(t, h[0], h[1]) + connect(t, h[0], h[2]) + connect(t, h[0], h[3]) + + // verify that the peers are connected + time.Sleep(2 * time.Second) + for i := 1; i < len(h); i++ { + if len(h[0].Network().ConnsToPeer(h[i].ID())) == 0 { + t.Fatal("expected a connection between peers") + } + } + + // build the mesh + var subs []*Subscription + for _, ps := range psubs { + sub, err := ps.Subscribe("test") + if err != nil { + t.Fatal(err) + } + subs = append(subs, sub) + } + + time.Sleep(time.Second) + + // publish a message + msg := []byte("message") + psubs[0].Publish("test", msg) + + assertReceive(t, subs[0], msg) + assertReceive(t, subs[1], msg) // Should match via semver over CustomSub name, ignoring the version + assertReceive(t, subs[2], msg) // Should match via GossipSubID_v11 + + // No message should be received because customsubA and customsubB have different names + ctxTimeout, timeoutCancel := context.WithTimeout(context.Background(), 1*time.Second) + defer timeoutCancel() + received := false + for { + msg, err := subs[3].Next(ctxTimeout) + if err != nil { + break + } + if msg != nil { + received = true + } + } + if received { + t.Fatal("Should not have received a message") + } +} + +func protocolNameMatch(base string) func(string) bool { + return func(check string) bool { + baseName := strings.Split(string(base), "/")[1] + checkName := strings.Split(string(check), "/")[1] + return baseName == checkName + } +} diff --git a/pubsub.go b/pubsub.go index ae6aa96e..f0296262 100644 --- a/pubsub.go +++ b/pubsub.go @@ -36,7 +36,7 @@ var ( var log = logging.Logger("pubsub") -type MatchingFunction func(string) func(string) bool +type ProtocolMatchFn = func(string) func(string) bool // PubSub is the implementation of the pubsub system. type PubSub struct { @@ -160,7 +160,7 @@ type PubSub struct { subFilter SubscriptionFilter // protoMatchFunc is a matching function for protocol selection. - protoMatchFunc *MatchingFunction + protoMatchFunc ProtocolMatchFn ctx context.Context } @@ -240,7 +240,6 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option peerOutboundQueueSize: 32, signID: h.ID(), signKey: nil, - protoMatchFunc: nil, signPolicy: StrictSign, incoming: make(chan *RPC, 32), newPeers: make(chan struct{}, 1), @@ -299,7 +298,7 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option for _, id := range rt.Protocols() { if ps.protoMatchFunc != nil { - h.SetStreamHandlerMatch(id, (*ps.protoMatchFunc)(string(id)), ps.handleNewStream) + h.SetStreamHandlerMatch(id, ps.protoMatchFunc(string(id)), ps.handleNewStream) } else { h.SetStreamHandler(id, ps.handleNewStream) } @@ -485,11 +484,13 @@ func WithMaxMessageSize(maxMessageSize int) Option { } } -// WithProtocolMatchFunction sets a custom matching function for protocol -// selection to be used by the protocol handler on the Host's Mux -func WithProtocolMatchFunction(m MatchingFunction) Option { +// WithProtocolMatchFn sets a custom matching function for protocol selection to +// be used by the protocol handler on the Host's Mux. Should be combined with +// WithGossipSubProtocols feature function for checking if certain protocol features +// are supported +func WithProtocolMatchFn(m ProtocolMatchFn) Option { return func(ps *PubSub) error { - ps.protoMatchFunc = &m + ps.protoMatchFunc = m return nil } }