diff --git a/waku/v2/protocol/filter/filter_test.go b/waku/v2/protocol/filter/filter_test.go index 68a70a82f..f1b9b8d4b 100644 --- a/waku/v2/protocol/filter/filter_test.go +++ b/waku/v2/protocol/filter/filter_test.go @@ -97,7 +97,7 @@ func (s *FilterTestSuite) makeWakuFilterFullNode(topic string) (*relay.WakuRelay node2Filter := NewWakuFilterFullNode(timesource.NewDefaultClock(), prometheus.DefaultRegisterer, s.log) node2Filter.SetHost(host) - sub := broadcaster.Register(topic) + sub := broadcaster.Register(protocol.NewContentFilter(topic)) err := node2Filter.Start(s.ctx, sub) s.Require().NoError(err) diff --git a/waku/v2/protocol/legacy_filter/waku_filter_test.go b/waku/v2/protocol/legacy_filter/waku_filter_test.go index 023d189c1..beb29f10d 100644 --- a/waku/v2/protocol/legacy_filter/waku_filter_test.go +++ b/waku/v2/protocol/legacy_filter/waku_filter_test.go @@ -82,7 +82,7 @@ func TestWakuFilter(t *testing.T) { node2Filter := NewWakuFilter(broadcaster, true, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, utils.Logger()) node2Filter.SetHost(host2) - sub := broadcaster.Register(testTopic) + sub := broadcaster.Register(protocol.NewContentFilter(testTopic)) err := node2Filter.Start(ctx, sub) require.NoError(t, err) @@ -170,7 +170,7 @@ func TestWakuFilterPeerFailure(t *testing.T) { require.NoError(t, broadcaster2.Start(context.Background())) node2Filter := NewWakuFilter(broadcaster2, true, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, utils.Logger(), WithTimeout(3*time.Second)) node2Filter.SetHost(host2) - sub := broadcaster.Register(testTopic) + sub := broadcaster.Register(protocol.NewContentFilter(testTopic)) err := node2Filter.Start(ctx, sub) require.NoError(t, err) diff --git a/waku/v2/protocol/relay/broadcast.go b/waku/v2/protocol/relay/broadcast.go index aea85e267..d201b467f 100644 --- a/waku/v2/protocol/relay/broadcast.go +++ b/waku/v2/protocol/relay/broadcast.go @@ -7,31 +7,32 @@ import ( "sync/atomic" "github.com/waku-org/go-waku/waku/v2/protocol" + "golang.org/x/exp/slices" ) -type chStore struct { +type subscriptions struct { mu sync.RWMutex - topicToChans map[string]map[int]chan *protocol.Envelope + topicsToSubs map[string]map[int]*Subscription //map of pubSubTopic to subscriptions id int } -func newChStore() chStore { - return chStore{ - topicToChans: make(map[string]map[int]chan *protocol.Envelope), +func newSubStore() subscriptions { + return subscriptions{ + topicsToSubs: make(map[string]map[int]*Subscription), } } -func (s *chStore) getNewCh(topic string, chLen int) Subscription { +func (s *subscriptions) getNewSubscription(contentFilter protocol.ContentFilter, chLen int) Subscription { ch := make(chan *protocol.Envelope, chLen) s.mu.Lock() defer s.mu.Unlock() s.id++ - // - if s.topicToChans[topic] == nil { - s.topicToChans[topic] = make(map[int]chan *protocol.Envelope) + pubsubTopic := contentFilter.PubsubTopic + if s.topicsToSubs[pubsubTopic] == nil { + s.topicsToSubs[pubsubTopic] = make(map[int]*Subscription) } id := s.id - s.topicToChans[topic][id] = ch - return Subscription{ + sub := Subscription{ + ID: id, // read only channel,will not block forever, returns once closed. Ch: ch, // Unsubscribe function is safe, can be called multiple times @@ -39,21 +40,25 @@ func (s *chStore) getNewCh(topic string, chLen int) Subscription { Unsubscribe: func() { s.mu.Lock() defer s.mu.Unlock() - if s.topicToChans[topic] == nil { + if s.topicsToSubs[pubsubTopic] == nil { return } - if ch := s.topicToChans[topic][id]; ch != nil { - close(ch) - delete(s.topicToChans[topic], id) + if sub := s.topicsToSubs[pubsubTopic][id]; sub != nil { + close(sub.Ch) + delete(s.topicsToSubs[pubsubTopic], id) } }, + contentFilter: contentFilter, } + s.topicsToSubs[pubsubTopic][id] = &sub + return sub } -func (s *chStore) broadcast(ctx context.Context, m *protocol.Envelope) { +func (s *subscriptions) broadcast(ctx context.Context, m *protocol.Envelope) { s.mu.RLock() defer s.mu.RUnlock() - for _, ch := range s.topicToChans[m.PubsubTopic()] { + for _, sub := range s.topicsToSubs[m.PubsubTopic()] { + select { // using ctx.Done for returning on cancellation is needed // reason: @@ -62,35 +67,43 @@ func (s *chStore) broadcast(ctx context.Context, m *protocol.Envelope) { // this will also block the chStore close function as it uses same mutex case <-ctx.Done(): return - case ch <- m: + default: + //Filter and notify only + // - if contentFilter doesn't have a contentTopic + // - if contentFilter has contentTopics and it matches with message + if len(sub.contentFilter.ContentTopicsList()) == 0 || (len(sub.contentFilter.ContentTopicsList()) > 0 && + slices.Contains[string](sub.contentFilter.ContentTopicsList(), m.Message().ContentTopic)) { + sub.Ch <- m + } } } - // send to all registered subscribers - for _, ch := range s.topicToChans[""] { + + // send to all wildcard subscribers + for _, sub := range s.topicsToSubs[""] { select { case <-ctx.Done(): return - case ch <- m: + case sub.Ch <- m: } } } -func (s *chStore) close() { +func (s *subscriptions) close() { s.mu.Lock() defer s.mu.Unlock() - for _, chans := range s.topicToChans { - for _, ch := range chans { - close(ch) + for _, subs := range s.topicsToSubs { + for _, sub := range subs { + close(sub.Ch) } } - s.topicToChans = nil + s.topicsToSubs = nil } // Broadcaster is used to create a fanout for an envelope that will be received by any subscriber interested in the topic of the message type Broadcaster interface { Start(ctx context.Context) error Stop() - Register(topic string, chLen ...int) Subscription + Register(contentFilter protocol.ContentFilter, chLen ...int) Subscription RegisterForAll(chLen ...int) Subscription Submit(*protocol.Envelope) } @@ -106,7 +119,7 @@ type broadcaster struct { cancel context.CancelFunc input chan *protocol.Envelope // - chStore chStore + chStore subscriptions running atomic.Bool } @@ -124,7 +137,7 @@ func (b *broadcaster) Start(ctx context.Context) error { } ctx, cancel := context.WithCancel(ctx) b.cancel = cancel - b.chStore = newChStore() + b.chStore = newSubStore() b.input = make(chan *protocol.Envelope, b.bufLen) go b.run(ctx) return nil @@ -154,15 +167,14 @@ func (b *broadcaster) Stop() { close(b.input) // close input channel } -// Register returns a subscription for an specific topic -func (b *broadcaster) Register(topic string, chLen ...int) Subscription { - return b.chStore.getNewCh(topic, getChLen(chLen)) +// Register returns a subscription for an specific pubsub topic and/or list of contentTopics +func (b *broadcaster) Register(contentFilter protocol.ContentFilter, chLen ...int) Subscription { + return b.chStore.getNewSubscription(contentFilter, getChLen(chLen)) } // RegisterForAll returns a subscription for all topics func (b *broadcaster) RegisterForAll(chLen ...int) Subscription { - - return b.chStore.getNewCh("", getChLen(chLen)) + return b.chStore.getNewSubscription(protocol.NewContentFilter(""), getChLen(chLen)) } func getChLen(chLen []int) int { diff --git a/waku/v2/protocol/relay/broadcast_test.go b/waku/v2/protocol/relay/broadcast_test.go index 59bd26b93..55422d228 100644 --- a/waku/v2/protocol/relay/broadcast_test.go +++ b/waku/v2/protocol/relay/broadcast_test.go @@ -46,7 +46,7 @@ func TestBroadcastSpecificTopic(t *testing.T) { for i := 0; i < 5; i++ { wg.Add(1) - sub := b.Register("abc") + sub := b.Register(protocol.NewContentFilter("abc")) go func() { defer wg.Done() @@ -66,7 +66,7 @@ func TestBroadcastSpecificTopic(t *testing.T) { func TestBroadcastCleanup(t *testing.T) { b := NewBroadcaster(100) require.NoError(t, b.Start(context.Background())) - sub := b.Register("test") + sub := b.Register(protocol.NewContentFilter("test")) b.Stop() <-sub.Ch sub.Unsubscribe() @@ -78,7 +78,7 @@ func TestBroadcastUnregisterSub(t *testing.T) { require.NoError(t, b.Start(context.Background())) subForAll := b.RegisterForAll() // unregister before submit - specificSub := b.Register("abc") + specificSub := b.Register(protocol.NewContentFilter("abc")) specificSub.Unsubscribe() // env := protocol.NewEnvelope(&pb.WakuMessage{}, utils.GetUnixEpoch(), "abc") diff --git a/waku/v2/protocol/relay/subscription.go b/waku/v2/protocol/relay/subscription.go index 670d6295b..b4720f3ae 100644 --- a/waku/v2/protocol/relay/subscription.go +++ b/waku/v2/protocol/relay/subscription.go @@ -4,6 +4,7 @@ import "github.com/waku-org/go-waku/waku/v2/protocol" // Subscription handles the details of a particular Topic subscription. There may be many subscriptions for a given topic. type Subscription struct { + ID int Unsubscribe func() Ch chan *protocol.Envelope contentFilter protocol.ContentFilter diff --git a/waku/v2/protocol/relay/waku_relay.go b/waku/v2/protocol/relay/waku_relay.go index 1e959320c..6a2ae05af 100644 --- a/waku/v2/protocol/relay/waku_relay.go +++ b/waku/v2/protocol/relay/waku_relay.go @@ -99,7 +99,8 @@ func msgIDFn(pmsg *pubsub_pb.Message) string { } // NewWakuRelay returns a new instance of a WakuRelay struct -func NewWakuRelay(bcaster Broadcaster, minPeersToPublish int, timesource timesource.Timesource, reg prometheus.Registerer, log *zap.Logger, opts ...pubsub.Option) *WakuRelay { +func NewWakuRelay(bcaster Broadcaster, minPeersToPublish int, timesource timesource.Timesource, + reg prometheus.Registerer, log *zap.Logger, opts ...pubsub.Option) *WakuRelay { w := new(WakuRelay) w.timesource = timesource w.wakuRelayTopics = make(map[string]*pubsub.Topic) @@ -232,6 +233,9 @@ func (w *WakuRelay) Start(ctx context.Context) error { } func (w *WakuRelay) start() error { + if w.bcaster == nil { + return fmt.Errorf("broadcaster not specified for relay") + } ps, err := pubsub.NewGossipSub(w.Context(), w.host, w.opts...) if err != nil { return err @@ -429,16 +433,14 @@ func (w *WakuRelay) subscribe(ctx context.Context, contentFilter waku_proto.Cont if err != nil { return nil, err } - /* TODO: Analyze what to do with this - if w.bcaster != nil { - _ = w.bcaster.Register(contentFilter.PubsubTopic, 1024) - } */ + + subscription := w.bcaster.Register(contentFilter, 1024) + // Create Content subscription - subscription := NewSubscription(contentFilter) w.topicsMutex.RLock() - w.contentSubs[pubSubTopic] = subscription + w.contentSubs[pubSubTopic] = &subscription w.topicsMutex.RUnlock() - subscriptions = append(subscriptions, subscription) + subscriptions = append(subscriptions, &subscription) go func() { <-ctx.Done() subscription.Unsubscribe() @@ -511,7 +513,9 @@ func (w *WakuRelay) Unsubscribe(ctx context.Context, contentFilter waku_proto.Co w.relaySubs[pubSubTopic].Cancel() delete(w.relaySubs, pubSubTopic) - //TODO: Any cancellation to be done? + //TODO: Unregister all subs from broadcaster + //cSub.Unsubscribe() + delete(w.contentSubs, pubSubTopic) evtHandler, ok := w.topicEvtHanders[pubSubTopic] @@ -579,9 +583,8 @@ func (w *WakuRelay) topicMsgHandler(pubsubTopic string, sub *pubsub.Subscription w.metrics.RecordMessage(envelope) - if w.bcaster != nil { - w.bcaster.Submit(envelope) - } + w.bcaster.Submit(envelope) + //Notify to all subscriptions for this topic sub, ok := w.contentSubs[pubsubTopic] if ok { diff --git a/waku/v2/protocol/relay/waku_relay_test.go b/waku/v2/protocol/relay/waku_relay_test.go index 45dc6c884..44511121b 100644 --- a/waku/v2/protocol/relay/waku_relay_test.go +++ b/waku/v2/protocol/relay/waku_relay_test.go @@ -27,15 +27,15 @@ func TestWakuRelay(t *testing.T) { host, err := tests.MakeHost(context.Background(), port, rand.Reader) require.NoError(t, err) - //bcaster := NewBroadcaster(10) - relay := NewWakuRelay(nil, 0, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, utils.Logger()) + bcaster := NewBroadcaster(10) + relay := NewWakuRelay(bcaster, 0, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, utils.Logger()) relay.SetHost(host) err = relay.Start(context.Background()) require.NoError(t, err) - //err = bcaster.Start(context.Background()) + err = bcaster.Start(context.Background()) require.NoError(t, err) - defer relay.Stop() + //defer relay.Stop() subs, err := relay.subscribe(context.Background(), protocol.NewContentFilter(testTopic)) //sub := bcaster.Register(testTopic) @@ -70,7 +70,6 @@ func TestWakuRelay(t *testing.T) { err = relay.Unsubscribe(ctx, protocol.NewContentFilter(testTopic)) require.NoError(t, err) - <-ctx.Done() } @@ -79,9 +78,10 @@ func createRelayNode(t *testing.T) (host.Host, *WakuRelay) { require.NoError(t, err) host, err := tests.MakeHost(context.Background(), port, rand.Reader) require.NoError(t, err) - - relay := NewWakuRelay(nil, 0, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, utils.Logger()) + bcaster := NewBroadcaster(10) + relay := NewWakuRelay(bcaster, 0, timesource.NewDefaultClock(), prometheus.DefaultRegisterer, utils.Logger()) relay.SetHost(host) + bcaster.Start(context.Background()) return host, relay }