Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for custom protocol matching function #440

Merged
merged 2 commits into from
Jul 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 84 additions & 0 deletions gossipsub_matchfn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package pubsub

import (
"context"
"strings"
"testing"
"time"

"github.com/libp2p/go-libp2p-core/protocol"
)

func TestGossipSubMatchingFn(t *testing.T) {
customsubA100 := protocol.ID("/customsub_a/1.0.0")
customsubA101Beta := protocol.ID("/customsub_a/1.0.1-beta")
customsubB100 := protocol.ID("/customsub_b/1.0.0")

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

h := getNetHosts(t, ctx, 4)
psubs := []*PubSub{
getGossipsub(ctx, h[0], WithProtocolMatchFn(protocolNameMatch), WithGossipSubProtocols([]protocol.ID{customsubA100, GossipSubID_v11}, GossipSubDefaultFeatures)),
getGossipsub(ctx, h[1], WithProtocolMatchFn(protocolNameMatch), WithGossipSubProtocols([]protocol.ID{customsubA101Beta}, GossipSubDefaultFeatures)),
getGossipsub(ctx, h[2], WithProtocolMatchFn(protocolNameMatch), WithGossipSubProtocols([]protocol.ID{GossipSubID_v11}, GossipSubDefaultFeatures)),
getGossipsub(ctx, h[3], WithProtocolMatchFn(protocolNameMatch), WithGossipSubProtocols([]protocol.ID{customsubB100}, GossipSubDefaultFeatures)),
}

connect(t, h[0], h[1])
connect(t, h[0], h[2])
connect(t, h[0], h[3])

// verify that the peers are connected
time.Sleep(2 * time.Second)
for i := 1; i < len(h); i++ {
if len(h[0].Network().ConnsToPeer(h[i].ID())) == 0 {
t.Fatal("expected a connection between peers")
}
}

// build the mesh
var subs []*Subscription
for _, ps := range psubs {
sub, err := ps.Subscribe("test")
if err != nil {
t.Fatal(err)
}
subs = append(subs, sub)
}

time.Sleep(time.Second)

// publish a message
msg := []byte("message")
psubs[0].Publish("test", msg)

assertReceive(t, subs[0], msg)
assertReceive(t, subs[1], msg) // Should match via semver over CustomSub name, ignoring the version
assertReceive(t, subs[2], msg) // Should match via GossipSubID_v11

// No message should be received because customsubA and customsubB have different names
ctxTimeout, timeoutCancel := context.WithTimeout(context.Background(), 1*time.Second)
defer timeoutCancel()
received := false
for {
msg, err := subs[3].Next(ctxTimeout)
if err != nil {
break
}
if msg != nil {
received = true
}
}
if received {
t.Fatal("Should not have received a message")
}
}

func protocolNameMatch(base string) func(string) bool {
return func(check string) bool {
baseName := strings.Split(string(base), "/")[1]
checkName := strings.Split(string(check), "/")[1]
return baseName == checkName
}
}
22 changes: 21 additions & 1 deletion pubsub.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ var (

var log = logging.Logger("pubsub")

type ProtocolMatchFn = func(string) func(string) bool

// PubSub is the implementation of the pubsub system.
type PubSub struct {
// atomic counter for seqnos
Expand Down Expand Up @@ -157,6 +159,9 @@ type PubSub struct {
// filter for tracking subscriptions in topics of interest; if nil, then we track all subscriptions
subFilter SubscriptionFilter

// protoMatchFunc is a matching function for protocol selection.
protoMatchFunc ProtocolMatchFn

ctx context.Context
}

Expand Down Expand Up @@ -292,7 +297,11 @@ func NewPubSub(ctx context.Context, h host.Host, rt PubSubRouter, opts ...Option
rt.Attach(ps)

for _, id := range rt.Protocols() {
h.SetStreamHandler(id, ps.handleNewStream)
if ps.protoMatchFunc != nil {
h.SetStreamHandlerMatch(id, ps.protoMatchFunc(string(id)), ps.handleNewStream)
} else {
h.SetStreamHandler(id, ps.handleNewStream)
}
}
h.Network().Notify((*PubSubNotif)(ps))

Expand Down Expand Up @@ -475,6 +484,17 @@ func WithMaxMessageSize(maxMessageSize int) Option {
}
}

// WithProtocolMatchFn sets a custom matching function for protocol selection to
// be used by the protocol handler on the Host's Mux. Should be combined with
// WithGossipSubProtocols feature function for checking if certain protocol features
// are supported
func WithProtocolMatchFn(m ProtocolMatchFn) Option {
return func(ps *PubSub) error {
ps.protoMatchFunc = m
return nil
}
}

// processLoop handles all inputs arriving on the channels
func (p *PubSub) processLoop(ctx context.Context) {
defer func() {
Expand Down