From e029ccf296d90b72aeda3d70eb26a539e752eacc Mon Sep 17 00:00:00 2001 From: myan Date: Wed, 4 Sep 2024 11:57:17 +0000 Subject: [PATCH] fix the race condition of mqtt Signed-off-by: myan add sending concurrently Signed-off-by: myan add err group Signed-off-by: myan --- protocol/mqtt_paho/v2/protocol.go | 13 ++- test/integration/mqtt_paho/concurrent_test.go | 98 +++++++++++++++++++ 2 files changed, 109 insertions(+), 2 deletions(-) create mode 100644 test/integration/mqtt_paho/concurrent_test.go diff --git a/protocol/mqtt_paho/v2/protocol.go b/protocol/mqtt_paho/v2/protocol.go index 261fc6c37..8900a3921 100644 --- a/protocol/mqtt_paho/v2/protocol.go +++ b/protocol/mqtt_paho/v2/protocol.go @@ -20,7 +20,6 @@ import ( type Protocol struct { client *paho.Client - config *paho.ClientConfig connOption *paho.Connect publishOption *paho.Publish subscribeOption *paho.Subscribe @@ -89,7 +88,7 @@ func (p *Protocol) Send(ctx context.Context, m binding.Message, transformers ... var err error defer m.Finish(err) - msg := p.publishOption + msg := p.publishMsg() if cecontext.TopicFrom(ctx) != "" { msg.Topic = cecontext.TopicFrom(ctx) cecontext.WithTopic(ctx, "") @@ -107,6 +106,16 @@ func (p *Protocol) Send(ctx context.Context, m binding.Message, transformers ... return err } +// publishMsg generate a new paho.Publish message from the p.publishOption +func (p *Protocol) publishMsg() *paho.Publish { + return &paho.Publish{ + QoS: p.publishOption.QoS, + Retain: p.publishOption.Retain, + Topic: p.publishOption.Topic, + Properties: p.publishOption.Properties, + } +} + func (p *Protocol) OpenInbound(ctx context.Context) error { if p.subscribeOption == nil { return fmt.Errorf("the paho.Subscribe option must not be nil") diff --git a/test/integration/mqtt_paho/concurrent_test.go b/test/integration/mqtt_paho/concurrent_test.go new file mode 100644 index 000000000..d5b7d7b13 --- /dev/null +++ b/test/integration/mqtt_paho/concurrent_test.go @@ -0,0 +1,98 @@ +/* +Copyright 2024 The CloudEvents Authors +SPDX-License-Identifier: Apache-2.0 +*/ + +package mqtt_paho + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + + cloudevents "github.com/cloudevents/sdk-go/v2" + cecontext "github.com/cloudevents/sdk-go/v2/context" +) + +func TestConcurrentSendingEvent(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + topicName := "test-ce-client-" + uuid.New().String() + + readyCh := make(chan bool) + defer close(readyCh) + + senderNum := 10 // 10 gorutine to sending the events + eventNum := 1000 // each gorutine sender publishs 1,000 events + + var g errgroup.Group + + // start a receiver + c, err := cloudevents.NewClient(protocolFactory(ctx, t, topicName), cloudevents.WithUUIDs()) + require.NoError(t, err) + g.Go(func() error { + // verify all of events can be recieved + count := senderNum * eventNum + var mu sync.Mutex + return c.StartReceiver(ctx, func(event cloudevents.Event) { + mu.Lock() + defer mu.Unlock() + count-- + if count == 0 { + readyCh <- true + } + }) + }) + // wait for 5 seconds to ensure the receiver starts safely + time.Sleep(5 * time.Second) + + // start a sender client to pulish events concurrently + client, err := cloudevents.NewClient(protocolFactory(ctx, t, topicName), cloudevents.WithUUIDs()) + require.NoError(t, err) + + evt := cloudevents.NewEvent() + evt.SetType("com.cloudevents.sample.sent") + evt.SetSource("concurrent-sender") + err = evt.SetData(cloudevents.ApplicationJSON, map[string]interface{}{"message": "Hello, World!"}) + require.NoError(t, err) + + for i := 0; i < senderNum; i++ { + g.Go(func() error { + for j := 0; j < eventNum; j++ { + result := client.Send( + cecontext.WithTopic(ctx, topicName), + evt, + ) + if result != nil { + return result + } + } + return nil + }) + } + + // wait until all the events are received + handleEvent(ctx, readyCh, cancel, t) + + require.NoError(t, g.Wait()) +} + +func handleEvent(ctx context.Context, readyCh <-chan bool, cancel context.CancelFunc, t *testing.T) { + for { + select { + case <-ctx.Done(): + require.Fail(t, "Test failed: timeout reached before events were received") + return + case <-readyCh: + cancel() + t.Logf("Test passed: events successfully received") + return + } + } +}