From a7ecbb1676fd8da9825100e0b749a47afab64455 Mon Sep 17 00:00:00 2001 From: Josh Kim Date: Sat, 4 Dec 2021 12:17:18 -0800 Subject: [PATCH 1/2] format: gofmt on all source files Signed-off-by: Josh Kim --- bridge/broker_connector_config.go | 30 +- bridge/broker_connector_test.go | 9 +- bridge/example_connector_broker_tcp_test.go | 134 +-- bridge/example_connector_broker_ws_test.go | 124 +-- bus/channel.go | 256 +++--- bus/channel_event_handler.go | 10 +- bus/channel_manager_test.go | 395 ++++---- bus/doc.go | 2 +- bus/example_galactic_channels_test.go | 160 ++-- bus/fabric_endpoint_test.go | 647 +++++++------ bus/message_handler.go | 82 +- bus/message_test.go | 27 +- bus/monitor_event.go | 39 +- bus/mutation_store_stream.go | 122 +-- bus/store.go | 754 +++++++-------- bus/store_manager_test.go | 317 ++++--- bus/store_stream.go | 126 +-- bus/store_sync_service.go | 463 +++++----- bus/store_sync_service_test.go | 964 ++++++++++---------- bus/transaction.go | 350 +++---- bus/transaction_test.go | 421 +++++---- log/logger.go | 80 +- model/store_responses.go | 50 +- model/util.go | 44 +- plank/pkg/metrics/pageview_metric.go | 4 +- plank/pkg/middleware/cache_control.go | 8 +- plank/pkg/middleware/prometheus_metrics.go | 4 +- plank/pkg/server/prometheus.go | 4 +- plank/pkg/server/spa_config.go | 16 +- plank/utils/console_helpers.go | 14 +- service/fabric_service.go | 9 +- service/rest_service_test.go | 662 +++++++------- service/service_lifecycle_manager.go | 18 +- service/service_lifecycle_manager_test.go | 3 +- service/service_registry.go | 8 +- service/service_registry_test.go | 262 +++--- stompserver/config.go | 57 +- stompserver/errors.go | 18 +- 38 files changed, 3335 insertions(+), 3358 deletions(-) diff --git a/bridge/broker_connector_config.go b/bridge/broker_connector_config.go index 02fca78..f7eb2d4 100644 --- a/bridge/broker_connector_config.go +++ b/bridge/broker_connector_config.go @@ -10,25 +10,25 @@ import ( ) type WebSocketConfig struct { - WSPath string // if UseWS is true, set this to your websocket path (e.g. '/fabric') - UseTLS bool // use TLS encryption with WebSocket connection + WSPath string // if UseWS is true, set this to your websocket path (e.g. '/fabric') + UseTLS bool // use TLS encryption with WebSocket connection TLSConfig *tls.Config // TLS config for WebSocket connection - CertFile string // X509 certificate for TLS - KeyFile string // matching key file for the X509 certificate + CertFile string // X509 certificate for TLS + KeyFile string // matching key file for the X509 certificate } // BrokerConnectorConfig is a configuration used when connecting to a message broker type BrokerConnectorConfig struct { - Username string - Password string - ServerAddr string - UseWS bool // use WebSocket instead of TCP - WebSocketConfig *WebSocketConfig // WebSocket configuration for when UseWS is true - HostHeader string - HeartBeatOut time.Duration // outbound heartbeat interval (from client to server) - HeartBeatIn time.Duration // inbound heartbeat interval (from server to client) - STOMPHeader map[string]string // additional STOMP headers for handshake - HttpHeader http.Header // additional HTTP headers for WebSocket Upgrade + Username string + Password string + ServerAddr string + UseWS bool // use WebSocket instead of TCP + WebSocketConfig *WebSocketConfig // WebSocket configuration for when UseWS is true + HostHeader string + HeartBeatOut time.Duration // outbound heartbeat interval (from client to server) + HeartBeatIn time.Duration // inbound heartbeat interval (from server to client) + STOMPHeader map[string]string // additional STOMP headers for handshake + HttpHeader http.Header // additional HTTP headers for WebSocket Upgrade } // LoadX509KeyPairFromFiles loads from paths to x509 cert and its matching key files and initializes @@ -46,4 +46,4 @@ func (b *WebSocketConfig) LoadX509KeyPairFromFiles(certFile, keyFile string) err } } return err -} \ No newline at end of file +} diff --git a/bridge/broker_connector_test.go b/bridge/broker_connector_test.go index e8e4431..0d84837 100644 --- a/bridge/broker_connector_test.go +++ b/bridge/broker_connector_test.go @@ -7,10 +7,6 @@ import ( "bufio" "bytes" "fmt" - "github.com/go-stomp/stomp/v3/frame" - "github.com/go-stomp/stomp/v3/server" - "github.com/gorilla/websocket" - "github.com/stretchr/testify/assert" "log" "net" "net/http" @@ -18,6 +14,11 @@ import ( "net/url" "testing" "time" + + "github.com/go-stomp/stomp/v3/frame" + "github.com/go-stomp/stomp/v3/server" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/assert" ) var upgrader = websocket.Upgrader{} diff --git a/bridge/example_connector_broker_tcp_test.go b/bridge/example_connector_broker_tcp_test.go index 6eceb21..c21c919 100644 --- a/bridge/example_connector_broker_tcp_test.go +++ b/bridge/example_connector_broker_tcp_test.go @@ -4,75 +4,75 @@ package bridge_test import ( - "fmt" - "github.com/vmware/transport-go/bridge" - "github.com/vmware/transport-go/bus" + "fmt" + "github.com/vmware/transport-go/bridge" + "github.com/vmware/transport-go/bus" ) func Example_connectUsingBrokerViaTCP() { - // get a reference to the event bus. - b := bus.GetBus() - - // create a broker connector configuration, using WebSockets. - // Make sure you have a STOMP TCP server running like RabbitMQ - config := &bridge.BrokerConnectorConfig{ - Username: "guest", - Password: "guest", - ServerAddr: ":61613", - STOMPHeader: map[string]string{ - "access-token": "test", - }, - } - - // connect to broker. - c, err := b.ConnectBroker(config) - if err != nil { - fmt.Printf("unable to connect, error: %e", err) - } - defer c.Disconnect() - - // subscribe to our demo simple-stream - s, _ := c.Subscribe("/queue/sample") - - // set a counter - n := 0 - - // create a control chan - done := make(chan bool) - - // listen for messages - var consumer = func() { - for { - // listen for incoming messages from subscription. - m := <-s.GetMsgChannel() - n++ - - // get byte array. - d := m.Payload.([]byte) - - fmt.Printf("Message Received: %s\n", string(d)) - // listen for 5 messages then stop. - if n >= 5 { - break - } - } - done <- true - } - - // send messages - var producer = func() { - for i := 0; i < 5; i++ { - c.SendMessage("/queue/sample", "text/plain", []byte(fmt.Sprintf("message: %d", i))) - } - } - - // listen for incoming messages on subscription for destination /queue/sample - go consumer() - - // send some messages to the broker on destination /queue/sample - go producer() - - // wait for messages to be processed. - <-done + // get a reference to the event bus. + b := bus.GetBus() + + // create a broker connector configuration, using WebSockets. + // Make sure you have a STOMP TCP server running like RabbitMQ + config := &bridge.BrokerConnectorConfig{ + Username: "guest", + Password: "guest", + ServerAddr: ":61613", + STOMPHeader: map[string]string{ + "access-token": "test", + }, + } + + // connect to broker. + c, err := b.ConnectBroker(config) + if err != nil { + fmt.Printf("unable to connect, error: %e", err) + } + defer c.Disconnect() + + // subscribe to our demo simple-stream + s, _ := c.Subscribe("/queue/sample") + + // set a counter + n := 0 + + // create a control chan + done := make(chan bool) + + // listen for messages + var consumer = func() { + for { + // listen for incoming messages from subscription. + m := <-s.GetMsgChannel() + n++ + + // get byte array. + d := m.Payload.([]byte) + + fmt.Printf("Message Received: %s\n", string(d)) + // listen for 5 messages then stop. + if n >= 5 { + break + } + } + done <- true + } + + // send messages + var producer = func() { + for i := 0; i < 5; i++ { + c.SendMessage("/queue/sample", "text/plain", []byte(fmt.Sprintf("message: %d", i))) + } + } + + // listen for incoming messages on subscription for destination /queue/sample + go consumer() + + // send some messages to the broker on destination /queue/sample + go producer() + + // wait for messages to be processed. + <-done } diff --git a/bridge/example_connector_broker_ws_test.go b/bridge/example_connector_broker_ws_test.go index a70a1d6..afdde2b 100644 --- a/bridge/example_connector_broker_ws_test.go +++ b/bridge/example_connector_broker_ws_test.go @@ -4,70 +4,70 @@ package bridge_test import ( - "encoding/json" - "fmt" - "github.com/vmware/transport-go/bridge" - "github.com/vmware/transport-go/bus" - "github.com/vmware/transport-go/model" + "encoding/json" + "fmt" + "github.com/vmware/transport-go/bridge" + "github.com/vmware/transport-go/bus" + "github.com/vmware/transport-go/model" ) func Example_connectUsingBrokerViaWebSocket() { - // get a reference to the event bus. - b := bus.GetBus() - - // create a broker connector configuration, using WebSockets. - config := &bridge.BrokerConnectorConfig{ - Username: "guest", - Password: "guest", - ServerAddr: "appfabric.vmware.com", - WebSocketConfig: &bridge.WebSocketConfig{WSPath: "/fabric"}, - UseWS: true, - STOMPHeader: map[string]string{ - "access-token": "test", - }, - } - - // connect to broker. - c, err := b.ConnectBroker(config) - if err != nil { - fmt.Printf("unable to connect, error: %e", err) - } - - // subscribe to our demo simple-stream - s, _ := c.Subscribe("/topic/simple-stream") - - // set a counter - n := 0 - - // create a control chan - done := make(chan bool) - - var listener = func() { - for { - // listen for incoming messages from subscription. - m := <-s.GetMsgChannel() - - // unmarshal message. - r := &model.Response{} - d := m.Payload.([]byte) - json.Unmarshal(d, &r) - fmt.Printf("Message Received: %s\n", r.Payload.(string)) - - n++ - - // listen for 5 messages then stop. - if n >= 5 { - break - } - } - done <- true - } - - // listen for incoming messages on subscription. - go listener() - - <-done - - c.Disconnect() + // get a reference to the event bus. + b := bus.GetBus() + + // create a broker connector configuration, using WebSockets. + config := &bridge.BrokerConnectorConfig{ + Username: "guest", + Password: "guest", + ServerAddr: "appfabric.vmware.com", + WebSocketConfig: &bridge.WebSocketConfig{WSPath: "/fabric"}, + UseWS: true, + STOMPHeader: map[string]string{ + "access-token": "test", + }, + } + + // connect to broker. + c, err := b.ConnectBroker(config) + if err != nil { + fmt.Printf("unable to connect, error: %e", err) + } + + // subscribe to our demo simple-stream + s, _ := c.Subscribe("/topic/simple-stream") + + // set a counter + n := 0 + + // create a control chan + done := make(chan bool) + + var listener = func() { + for { + // listen for incoming messages from subscription. + m := <-s.GetMsgChannel() + + // unmarshal message. + r := &model.Response{} + d := m.Payload.([]byte) + json.Unmarshal(d, &r) + fmt.Printf("Message Received: %s\n", r.Payload.(string)) + + n++ + + // listen for 5 messages then stop. + if n >= 5 { + break + } + } + done <- true + } + + // listen for incoming messages on subscription. + go listener() + + <-done + + c.Disconnect() } diff --git a/bus/channel.go b/bus/channel.go index 440bfda..5c9f798 100644 --- a/bus/channel.go +++ b/bus/channel.go @@ -4,220 +4,220 @@ package bus import ( - "github.com/google/uuid" - "github.com/vmware/transport-go/bridge" - "github.com/vmware/transport-go/model" - "sync" - "sync/atomic" + "github.com/google/uuid" + "github.com/vmware/transport-go/bridge" + "github.com/vmware/transport-go/model" + "sync" + "sync/atomic" ) // Channel represents the stream and the subscribed event handlers waiting for ticks on the stream type Channel struct { - Name string `json:"string"` - eventHandlers []*channelEventHandler - galactic bool - galacticMappedDestination string - private bool - channelLock sync.Mutex - wg sync.WaitGroup - brokerSubs []*connectionSub - brokerConns []bridge.Connection - brokerMappedEvent chan bool + Name string `json:"string"` + eventHandlers []*channelEventHandler + galactic bool + galacticMappedDestination string + private bool + channelLock sync.Mutex + wg sync.WaitGroup + brokerSubs []*connectionSub + brokerConns []bridge.Connection + brokerMappedEvent chan bool } // Create a new Channel with the supplied Channel name. Returns a pointer to that Channel. func NewChannel(channelName string) *Channel { - c := &Channel{ - Name: channelName, - eventHandlers: []*channelEventHandler{}, - channelLock: sync.Mutex{}, - galactic: false, - private: false, - wg: sync.WaitGroup{}, - brokerMappedEvent: make(chan bool, 10), - brokerConns: []bridge.Connection{}, - brokerSubs: []*connectionSub{}} - return c + c := &Channel{ + Name: channelName, + eventHandlers: []*channelEventHandler{}, + channelLock: sync.Mutex{}, + galactic: false, + private: false, + wg: sync.WaitGroup{}, + brokerMappedEvent: make(chan bool, 10), + brokerConns: []bridge.Connection{}, + brokerSubs: []*connectionSub{}} + return c } // Mark the Channel as private func (channel *Channel) SetPrivate(private bool) { - channel.private = private + channel.private = private } // Mark the Channel as galactic func (channel *Channel) SetGalactic(mappedDestination string) { - channel.galactic = true - channel.galacticMappedDestination = mappedDestination + channel.galactic = true + channel.galacticMappedDestination = mappedDestination } // Mark the Channel as local func (channel *Channel) SetLocal() { - channel.galactic = false - channel.galacticMappedDestination = "" + channel.galactic = false + channel.galacticMappedDestination = "" } // Returns true is the Channel is marked as galactic func (channel *Channel) IsGalactic() bool { - return channel.galactic + return channel.galactic } // Returns true if the Channel is marked as private func (channel *Channel) IsPrivate() bool { - return channel.private + return channel.private } // Send a new message on this Channel, to all event handlers. func (channel *Channel) Send(message *model.Message) { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() - if eventHandlers := channel.eventHandlers; len(eventHandlers) > 0 { - - // if a handler is run once only, then the slice will be mutated mid cycle. - // copy slice to ensure that removed handler is still fired. - handlerDuplicate := make([]*channelEventHandler, 0, len(eventHandlers)) - handlerDuplicate = append(handlerDuplicate, eventHandlers...) - for n, eventHandler := range handlerDuplicate { - if eventHandler.runOnce && atomic.LoadInt64(&eventHandler.runCount) > 0 { - channel.removeEventHandler(n) // remove from slice. - continue - } - channel.wg.Add(1) - go channel.sendMessageToHandler(eventHandler, message) - } - } + channel.channelLock.Lock() + defer channel.channelLock.Unlock() + if eventHandlers := channel.eventHandlers; len(eventHandlers) > 0 { + + // if a handler is run once only, then the slice will be mutated mid cycle. + // copy slice to ensure that removed handler is still fired. + handlerDuplicate := make([]*channelEventHandler, 0, len(eventHandlers)) + handlerDuplicate = append(handlerDuplicate, eventHandlers...) + for n, eventHandler := range handlerDuplicate { + if eventHandler.runOnce && atomic.LoadInt64(&eventHandler.runCount) > 0 { + channel.removeEventHandler(n) // remove from slice. + continue + } + channel.wg.Add(1) + go channel.sendMessageToHandler(eventHandler, message) + } + } } // Check if the Channel has any registered subscribers func (channel *Channel) ContainsHandlers() bool { - return len(channel.eventHandlers) > 0 + return len(channel.eventHandlers) > 0 } // Send message to handler function func (channel *Channel) sendMessageToHandler(handler *channelEventHandler, message *model.Message) { - handler.callBackFunction(message) - atomic.AddInt64(&handler.runCount, 1) - channel.wg.Done() + handler.callBackFunction(message) + atomic.AddInt64(&handler.runCount, 1) + channel.wg.Done() } // Subscribe a new handler function. func (channel *Channel) subscribeHandler(handler *channelEventHandler) { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() - channel.eventHandlers = append(channel.eventHandlers, handler) + channel.channelLock.Lock() + defer channel.channelLock.Unlock() + channel.eventHandlers = append(channel.eventHandlers, handler) } func (channel *Channel) unsubscribeHandler(uuid *uuid.UUID) bool { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() + channel.channelLock.Lock() + defer channel.channelLock.Unlock() - for i, handler := range channel.eventHandlers { - if handler.uuid.ID() == uuid.ID() { - channel.removeEventHandler(i) - return true - } - } - return false + for i, handler := range channel.eventHandlers { + if handler.uuid.ID() == uuid.ID() { + channel.removeEventHandler(i) + return true + } + } + return false } // Remove handler function from being subscribed to the Channel. func (channel *Channel) removeEventHandler(index int) { - numHandlers := len(channel.eventHandlers) - if numHandlers <= 0 { - return - } - if index >= numHandlers { - return - } + numHandlers := len(channel.eventHandlers) + if numHandlers <= 0 { + return + } + if index >= numHandlers { + return + } - // delete from event handler slice. - copy(channel.eventHandlers[index:], channel.eventHandlers[index+1:]) - channel.eventHandlers[numHandlers-1] = nil - channel.eventHandlers = channel.eventHandlers[:numHandlers-1] + // delete from event handler slice. + copy(channel.eventHandlers[index:], channel.eventHandlers[index+1:]) + channel.eventHandlers[numHandlers-1] = nil + channel.eventHandlers = channel.eventHandlers[:numHandlers-1] } func (channel *Channel) listenToBrokerSubscription(sub bridge.Subscription) { - for { - msg, m := <-sub.GetMsgChannel() - if m { - channel.Send(msg) - } else { - break - } - } + for { + msg, m := <-sub.GetMsgChannel() + if m { + channel.Send(msg) + } else { + break + } + } } func (channel *Channel) isBrokerSubscribed(sub bridge.Subscription) bool { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() + channel.channelLock.Lock() + defer channel.channelLock.Unlock() - for _, cs := range channel.brokerSubs { - if sub.GetId().ID() == cs.s.GetId().ID() { - return true - } - } - return false + for _, cs := range channel.brokerSubs { + if sub.GetId().ID() == cs.s.GetId().ID() { + return true + } + } + return false } func (channel *Channel) isBrokerSubscribedToDestination(c bridge.Connection, dest string) bool { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() + channel.channelLock.Lock() + defer channel.channelLock.Unlock() - for _, cs := range channel.brokerSubs { - if cs.s != nil && cs.s.GetDestination() == dest && cs.c != nil && cs.c.GetId() == c.GetId() { - return true - } - } - return false + for _, cs := range channel.brokerSubs { + if cs.s != nil && cs.s.GetDestination() == dest && cs.c != nil && cs.c.GetId() == c.GetId() { + return true + } + } + return false } func (channel *Channel) addBrokerConnection(c bridge.Connection) { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() + channel.channelLock.Lock() + defer channel.channelLock.Unlock() - for _, brCon := range channel.brokerConns { - if brCon.GetId() == c.GetId() { - return - } - } + for _, brCon := range channel.brokerConns { + if brCon.GetId() == c.GetId() { + return + } + } - channel.brokerConns = append(channel.brokerConns, c) + channel.brokerConns = append(channel.brokerConns, c) } func (channel *Channel) removeBrokerConnections() { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() + channel.channelLock.Lock() + defer channel.channelLock.Unlock() - channel.brokerConns = []bridge.Connection{} + channel.brokerConns = []bridge.Connection{} } func (channel *Channel) addBrokerSubscription(conn bridge.Connection, sub bridge.Subscription) { - cs := &connectionSub{c: conn, s: sub} + cs := &connectionSub{c: conn, s: sub} - channel.channelLock.Lock() - channel.brokerSubs = append(channel.brokerSubs, cs) - channel.channelLock.Unlock() + channel.channelLock.Lock() + channel.brokerSubs = append(channel.brokerSubs, cs) + channel.channelLock.Unlock() - go channel.listenToBrokerSubscription(sub) + go channel.listenToBrokerSubscription(sub) } func (channel *Channel) removeBrokerSubscription(sub bridge.Subscription) { - channel.channelLock.Lock() - defer channel.channelLock.Unlock() + channel.channelLock.Lock() + defer channel.channelLock.Unlock() - for i, cs := range channel.brokerSubs { - if sub.GetId().ID() == cs.s.GetId().ID() { - channel.brokerSubs = removeSub(channel.brokerSubs, i) - } - } + for i, cs := range channel.brokerSubs { + if sub.GetId().ID() == cs.s.GetId().ID() { + channel.brokerSubs = removeSub(channel.brokerSubs, i) + } + } } func removeSub(s []*connectionSub, i int) []*connectionSub { - s[len(s)-1], s[i] = s[i], s[len(s)-1] - return s[:len(s)-1] + s[len(s)-1], s[i] = s[i], s[len(s)-1] + return s[:len(s)-1] } type connectionSub struct { - c bridge.Connection - s bridge.Subscription + c bridge.Connection + s bridge.Subscription } diff --git a/bus/channel_event_handler.go b/bus/channel_event_handler.go index 94aa133..d445eee 100644 --- a/bus/channel_event_handler.go +++ b/bus/channel_event_handler.go @@ -4,12 +4,12 @@ package bus import ( - "github.com/google/uuid" + "github.com/google/uuid" ) type channelEventHandler struct { - callBackFunction MessageHandlerFunction - runOnce bool - runCount int64 - uuid *uuid.UUID + callBackFunction MessageHandlerFunction + runOnce bool + runCount int64 + uuid *uuid.UUID } diff --git a/bus/channel_manager_test.go b/bus/channel_manager_test.go index 9c5f564..ee46206 100644 --- a/bus/channel_manager_test.go +++ b/bus/channel_manager_test.go @@ -4,312 +4,307 @@ package bus import ( - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/vmware/transport-go/model" - "sync" - "testing" - "time" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/vmware/transport-go/model" + "sync" + "testing" + "time" ) var testChannelManager ChannelManager var testChannelManagerChannelName = "melody" func createManager() (ChannelManager, EventBus) { - b := newTestEventBus() - manager := NewBusChannelManager(b) - return manager, b + b := newTestEventBus() + manager := NewBusChannelManager(b) + return manager, b } func TestChannelManager_Boot(t *testing.T) { - testChannelManager, _ = createManager() - assert.Len(t, testChannelManager.GetAllChannels(), 0) + testChannelManager, _ = createManager() + assert.Len(t, testChannelManager.GetAllChannels(), 0) } func TestChannelManager_CreateChannel(t *testing.T) { - var bus EventBus - testChannelManager, bus = createManager() + var bus EventBus + testChannelManager, bus = createManager() - wg := sync.WaitGroup{} - wg.Add(1) - bus.AddMonitorEventListener( - func(monitorEvt *MonitorEvent) { - if monitorEvt.EntityName == testChannelManagerChannelName { - assert.Equal(t, monitorEvt.EventType, ChannelCreatedEvt) - wg.Done() - } - }) + wg := sync.WaitGroup{} + wg.Add(1) + bus.AddMonitorEventListener( + func(monitorEvt *MonitorEvent) { + if monitorEvt.EntityName == testChannelManagerChannelName { + assert.Equal(t, monitorEvt.EventType, ChannelCreatedEvt) + wg.Done() + } + }) - testChannelManager.CreateChannel(testChannelManagerChannelName) + testChannelManager.CreateChannel(testChannelManagerChannelName) - wg.Wait() + wg.Wait() - assert.Len(t, testChannelManager.GetAllChannels(), 1) + assert.Len(t, testChannelManager.GetAllChannels(), 1) - fetchedChannel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) - assert.NotNil(t, fetchedChannel) - assert.True(t, testChannelManager.CheckChannelExists(testChannelManagerChannelName)) + fetchedChannel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) + assert.NotNil(t, fetchedChannel) + assert.True(t, testChannelManager.CheckChannelExists(testChannelManagerChannelName)) } func TestChannelManager_GetNotExistentChannel(t *testing.T) { - testChannelManager, _ = createManager() + testChannelManager, _ = createManager() - fetchedChannel, err := testChannelManager.GetChannel(testChannelManagerChannelName) - assert.NotNil(t, err) - assert.Nil(t, fetchedChannel) + fetchedChannel, err := testChannelManager.GetChannel(testChannelManagerChannelName) + assert.NotNil(t, err) + assert.Nil(t, fetchedChannel) } func TestChannelManager_DestroyChannel(t *testing.T) { - testChannelManager, _ = createManager() - - testChannelManager.CreateChannel(testChannelManagerChannelName) - testChannelManager.DestroyChannel(testChannelManagerChannelName) - fetchedChannel, err := testChannelManager.GetChannel(testChannelManagerChannelName) - assert.Len(t, testChannelManager.GetAllChannels(), 0) - assert.NotNil(t, err) - assert.Nil(t, fetchedChannel) + testChannelManager, _ = createManager() + + testChannelManager.CreateChannel(testChannelManagerChannelName) + testChannelManager.DestroyChannel(testChannelManagerChannelName) + fetchedChannel, err := testChannelManager.GetChannel(testChannelManagerChannelName) + assert.Len(t, testChannelManager.GetAllChannels(), 0) + assert.NotNil(t, err) + assert.Nil(t, fetchedChannel) } func TestChannelManager_SubscribeChannelHandler(t *testing.T) { - testChannelManager, _ = createManager() - testChannelManager.CreateChannel(testChannelManagerChannelName) - - handler := func(*model.Message) {} - uuid, err := testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) - assert.Nil(t, err) - assert.NotNil(t, uuid) - channel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) - assert.Len(t, channel.eventHandlers, 1) + testChannelManager, _ = createManager() + testChannelManager.CreateChannel(testChannelManagerChannelName) + + handler := func(*model.Message) {} + uuid, err := testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) + assert.Nil(t, err) + assert.NotNil(t, uuid) + channel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) + assert.Len(t, channel.eventHandlers, 1) } func TestChannelManager_SubscribeChannelHandlerMissingChannel(t *testing.T) { - testChannelManager, _ = createManager() - handler := func(*model.Message) {} - _, err := testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) - assert.NotNil(t, err) + testChannelManager, _ = createManager() + handler := func(*model.Message) {} + _, err := testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) + assert.NotNil(t, err) } func TestChannelManager_UnsubscribeChannelHandler(t *testing.T) { - testChannelManager, _ = createManager() - testChannelManager.CreateChannel(testChannelManagerChannelName) + testChannelManager, _ = createManager() + testChannelManager.CreateChannel(testChannelManagerChannelName) - handler := func(*model.Message) {} - uuid, _ := testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) - channel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) - assert.Len(t, channel.eventHandlers, 1) + handler := func(*model.Message) {} + uuid, _ := testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) + channel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) + assert.Len(t, channel.eventHandlers, 1) - err := testChannelManager.UnsubscribeChannelHandler(testChannelManagerChannelName, uuid) - assert.Nil(t, err) - assert.Len(t, channel.eventHandlers, 0) + err := testChannelManager.UnsubscribeChannelHandler(testChannelManagerChannelName, uuid) + assert.Nil(t, err) + assert.Len(t, channel.eventHandlers, 0) } func TestChannelManager_UnsubscribeChannelHandlerMissingChannel(t *testing.T) { - testChannelManager, _ = createManager() - uuid := uuid.New() - err := testChannelManager.UnsubscribeChannelHandler(testChannelManagerChannelName, &uuid) - assert.NotNil(t, err) + testChannelManager, _ = createManager() + uuid := uuid.New() + err := testChannelManager.UnsubscribeChannelHandler(testChannelManagerChannelName, &uuid) + assert.NotNil(t, err) } func TestChannelManager_UnsubscribeChannelHandlerNoId(t *testing.T) { - testChannelManager, _ = createManager() - testChannelManager.CreateChannel(testChannelManagerChannelName) - - handler := func(*model.Message) {} - testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) - channel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) - assert.Len(t, channel.eventHandlers, 1) - id := uuid.New() - err := testChannelManager.UnsubscribeChannelHandler(testChannelManagerChannelName, &id) - assert.NotNil(t, err) - assert.Len(t, channel.eventHandlers, 1) + testChannelManager, _ = createManager() + testChannelManager.CreateChannel(testChannelManagerChannelName) + + handler := func(*model.Message) {} + testChannelManager.SubscribeChannelHandler(testChannelManagerChannelName, handler, false) + channel, _ := testChannelManager.GetChannel(testChannelManagerChannelName) + assert.Len(t, channel.eventHandlers, 1) + id := uuid.New() + err := testChannelManager.UnsubscribeChannelHandler(testChannelManagerChannelName, &id) + assert.NotNil(t, err) + assert.Len(t, channel.eventHandlers, 1) } func TestChannelManager_TestWaitForGroupOnBadChannel(t *testing.T) { - testChannelManager, _ = createManager() - err := testChannelManager.WaitForChannel("unknown") - assert.Error(t, err, "no such Channel as 'unknown'") + testChannelManager, _ = createManager() + err := testChannelManager.WaitForChannel("unknown") + assert.Error(t, err, "no such Channel as 'unknown'") } func TestChannelManager_TestGalacticChannelOpen(t *testing.T) { - testChannelManager, _ = createManager() - galacticChannel := testChannelManager.CreateChannel(testChannelManagerChannelName) - id := uuid.New() + testChannelManager, _ = createManager() + galacticChannel := testChannelManager.CreateChannel(testChannelManagerChannelName) + id := uuid.New() - // mark channel as galactic. + // mark channel as galactic. - subId := uuid.New() - sub := &MockBridgeSubscription{ - Id: &subId, - } + subId := uuid.New() + sub := &MockBridgeSubscription{ + Id: &subId, + } - c := &MockBridgeConnection{Id: &id} - c.On("Subscribe", "/topic/testy-test").Return(sub, nil).Once() - e := testChannelManager.MarkChannelAsGalactic(testChannelManagerChannelName, "/topic/testy-test", c) + c := &MockBridgeConnection{Id: &id} + c.On("Subscribe", "/topic/testy-test").Return(sub, nil).Once() + e := testChannelManager.MarkChannelAsGalactic(testChannelManagerChannelName, "/topic/testy-test", c) - assert.Nil(t, e) - c.AssertExpectations(t) + assert.Nil(t, e) + c.AssertExpectations(t) - assert.True(t, galacticChannel.galactic) + assert.True(t, galacticChannel.galactic) - assert.Equal(t, len(galacticChannel.brokerConns), 1) - assert.Equal(t, galacticChannel.brokerConns[0], c) + assert.Equal(t, len(galacticChannel.brokerConns), 1) + assert.Equal(t, galacticChannel.brokerConns[0], c) - assert.Equal(t, len(galacticChannel.brokerSubs), 1) - assert.Equal(t, galacticChannel.brokerSubs[0].s, sub) - assert.Equal(t, galacticChannel.brokerSubs[0].c, c) + assert.Equal(t, len(galacticChannel.brokerSubs), 1) + assert.Equal(t, galacticChannel.brokerSubs[0].s, sub) + assert.Equal(t, galacticChannel.brokerSubs[0].c, c) - testChannelManager.MarkChannelAsLocal(testChannelManagerChannelName) - assert.False(t, galacticChannel.galactic) + testChannelManager.MarkChannelAsLocal(testChannelManagerChannelName) + assert.False(t, galacticChannel.galactic) - assert.Equal(t, len(galacticChannel.brokerConns), 0) - assert.Equal(t, len(galacticChannel.brokerSubs), 0) + assert.Equal(t, len(galacticChannel.brokerConns), 0) + assert.Equal(t, len(galacticChannel.brokerSubs), 0) } func TestChannelManager_TestGalacticChannelOpenError(t *testing.T) { - // channel is not open / does not exist, so this should fail. - e := testChannelManager.MarkChannelAsGalactic(evtbusTestChannelName, "/topic/testy-test", nil) - assert.Error(t, e) + // channel is not open / does not exist, so this should fail. + e := testChannelManager.MarkChannelAsGalactic(evtbusTestChannelName, "/topic/testy-test", nil) + assert.Error(t, e) } func TestChannelManager_TestGalacticChannelCloseError(t *testing.T) { - // channel is not open / does not exist, so this should fail. - e := testChannelManager.MarkChannelAsLocal(evtbusTestChannelName) - assert.Error(t, e) + // channel is not open / does not exist, so this should fail. + e := testChannelManager.MarkChannelAsLocal(evtbusTestChannelName) + assert.Error(t, e) } func TestChannelManager_TestListenToMonitorGalactic(t *testing.T) { - myChan := "mychan" + myChan := "mychan" - b := newTestEventBus() + b := newTestEventBus() - testChannelManager = b.GetChannelManager() - c := testChannelManager.CreateChannel(myChan) + testChannelManager = b.GetChannelManager() + c := testChannelManager.CreateChannel(myChan) + // mark channel as galactic. + id := uuid.New() + subId := uuid.New() + mockSub := &MockBridgeSubscription{ + Id: &subId, + Channel: make(chan *model.Message, 10), + Destination: "/queue/hiya", + } + mockCon := &MockBridgeConnection{Id: &id} + mockCon.On("Subscribe", "/queue/hiya").Return(mockSub, nil).Once() - // mark channel as galactic. - id := uuid.New() - subId := uuid.New() - mockSub := &MockBridgeSubscription{ - Id: &subId, - Channel: make(chan *model.Message, 10), - Destination: "/queue/hiya", - } + x := 0 - mockCon := &MockBridgeConnection{Id: &id} - mockCon.On("Subscribe", "/queue/hiya").Return(mockSub, nil).Once() + h, e := b.ListenOnce(myChan) + assert.Nil(t, e) - x := 0 + var m1 = make(chan bool) + var m2 = make(chan bool) - h, e := b.ListenOnce(myChan) - assert.Nil(t, e) + h.Handle( + func(msg *model.Message) { + x++ + m1 <- true + }, + func(err error) { - var m1 = make(chan bool) - var m2 = make(chan bool) + }) + testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon) + testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon) // double up for fun + <-c.brokerMappedEvent + assert.Len(t, c.brokerConns, 1) + mockSub.GetMsgChannel() <- &model.Message{Payload: "test-message", Direction: model.ResponseDir} + <-m1 - h.Handle( - func(msg *model.Message) { - x++ - m1 <- true - }, - func(err error) { + // lets add another connection to the same channel. - }) + id2 := uuid.New() + subId2 := uuid.New() + mockSub2 := &MockBridgeSubscription{ + Id: &subId2, + Channel: make(chan *model.Message, 10), + Destination: "/queue/hiya", + } - testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon) - testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon) // double up for fun - <-c.brokerMappedEvent - assert.Len(t, c.brokerConns, 1) - mockSub.GetMsgChannel() <- &model.Message{Payload: "test-message", Direction: model.ResponseDir} - <-m1 + mockCon2 := &MockBridgeConnection{Id: &id2} + mockCon2.On("Subscribe", "/queue/hiya").Return(mockSub2, nil).Once() - // lets add another connection to the same channel. + h, e = b.ListenOnce(myChan) - id2 := uuid.New() - subId2 := uuid.New() - mockSub2 := &MockBridgeSubscription{ - Id: &subId2, - Channel: make(chan *model.Message, 10), - Destination: "/queue/hiya", - } + h.Handle( + func(msg *model.Message) { + x++ + m2 <- true + }, + func(err error) {}) - mockCon2 := &MockBridgeConnection{Id: &id2} - mockCon2.On("Subscribe", "/queue/hiya").Return(mockSub2, nil).Once() + testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon2) + testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon2) // trigger double (should ignore) - h, e = b.ListenOnce(myChan) + select { + case <-c.brokerMappedEvent: + case <-time.After(5 * time.Second): + assert.FailNow(t, "TestChannelManager_TestListenToMonitorGalactic timeout on brokerMappedEvent") + } - h.Handle( - func(msg *model.Message) { - x++ - m2 <- true - }, - func(err error) {}) + mockSub.GetMsgChannel() <- &model.Message{Payload: "Hi baby melody!", Direction: model.ResponseDir} - testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon2) - testChannelManager.MarkChannelAsGalactic(myChan, "/queue/hiya", mockCon2) // trigger double (should ignore) - - select { - case <-c.brokerMappedEvent: - case <-time.After(5 * time.Second): - assert.FailNow(t, "TestChannelManager_TestListenToMonitorGalactic timeout on brokerMappedEvent") - } - - mockSub.GetMsgChannel() <- &model.Message{Payload: "Hi baby melody!", Direction: model.ResponseDir} - - <-m2 - assert.Equal(t, 2, x) + <-m2 + assert.Equal(t, 2, x) } - // This test performs a end to end run of the monitor. // it will create a ws broker subscription, map it to a single channel // then it will unsubscribe and check that the unsubscription went through ok. func TestChannelManager_TestListenToMonitorLocal(t *testing.T) { - myChan := "mychan-local" - - b := newTestEventBus() + myChan := "mychan-local" - // run ws broker - testChannelManager = b.GetChannelManager() + b := newTestEventBus() - c := testChannelManager.CreateChannel(myChan) + // run ws broker + testChannelManager = b.GetChannelManager() + c := testChannelManager.CreateChannel(myChan) - subId := uuid.New() - sub := &MockBridgeSubscription{ - Id: &subId, - } + subId := uuid.New() + sub := &MockBridgeSubscription{ + Id: &subId, + } - id := uuid.New() - mockCon := &MockBridgeConnection{Id: &id} - mockCon.On("Subscribe", "/queue/seeya").Return(sub, nil).Once() + id := uuid.New() + mockCon := &MockBridgeConnection{Id: &id} + mockCon.On("Subscribe", "/queue/seeya").Return(sub, nil).Once() - testChannelManager.MarkChannelAsGalactic(myChan, "/queue/seeya", mockCon) - <-c.brokerMappedEvent - assert.Len(t, c.brokerConns, 1) + testChannelManager.MarkChannelAsGalactic(myChan, "/queue/seeya", mockCon) + <-c.brokerMappedEvent + assert.Len(t, c.brokerConns, 1) - testChannelManager.MarkChannelAsLocal(myChan) - <-c.brokerMappedEvent - assert.Len(t, c.brokerConns, 0) - assert.Len(t, c.brokerSubs, 0) + testChannelManager.MarkChannelAsLocal(myChan) + <-c.brokerMappedEvent + assert.Len(t, c.brokerConns, 0) + assert.Len(t, c.brokerSubs, 0) } func TestChannelManager_TestGalacticMonitorInvalidChannel(t *testing.T) { - testChannelManager, _ = createManager() - testChannelManager.CreateChannel("fun-chan") + testChannelManager, _ = createManager() + testChannelManager.CreateChannel("fun-chan") - err := testChannelManager.MarkChannelAsGalactic("fun-chan", "/queue/woo", nil) - assert.Nil(t, err) + err := testChannelManager.MarkChannelAsGalactic("fun-chan", "/queue/woo", nil) + assert.Nil(t, err) } func TestChannelManager_TestLocalMonitorInvalidChannel(t *testing.T) { - testChannelManager, _ = createManager() - testChannelManager.CreateChannel("fun-chan") + testChannelManager, _ = createManager() + testChannelManager.CreateChannel("fun-chan") - err := testChannelManager.MarkChannelAsLocal("fun-chan") - assert.Nil(t, err) + err := testChannelManager.MarkChannelAsLocal("fun-chan") + assert.Nil(t, err) } diff --git a/bus/doc.go b/bus/doc.go index e3d8cc2..dc4e790 100644 --- a/bus/doc.go +++ b/bus/doc.go @@ -3,5 +3,5 @@ /* Package bus contains all things bus. - */ +*/ package bus diff --git a/bus/example_galactic_channels_test.go b/bus/example_galactic_channels_test.go index 99cea86..2cfaef8 100644 --- a/bus/example_galactic_channels_test.go +++ b/bus/example_galactic_channels_test.go @@ -4,87 +4,87 @@ package bus_test import ( - "encoding/json" - "fmt" - "github.com/vmware/transport-go/bridge" - "github.com/vmware/transport-go/bus" - "github.com/vmware/transport-go/model" - "log" + "encoding/json" + "fmt" + "github.com/vmware/transport-go/bridge" + "github.com/vmware/transport-go/bus" + "github.com/vmware/transport-go/model" + "log" ) func Example_usingGalacticChannels() { - // get a pointer to the bus. - b := bus.GetBus() - - // get a pointer to the channel manager - cm := b.GetChannelManager() - - channel := "my-stream" - cm.CreateChannel(channel) - - // create done signal - var done = make(chan bool) - - // listen to stream of messages coming in on channel. - h, err := b.ListenStream(channel) - - if err != nil { - log.Panicf("unable to listen to channel stream, error: %e", err) - } - - count := 0 - - // listen for five messages and then exit, send a completed signal on channel. - h.Handle( - func(msg *model.Message) { - - // unmarshal the payload into a Response object (used by fabric services) - r := &model.Response{} - d := msg.Payload.([]byte) - json.Unmarshal(d, &r) - fmt.Printf("Stream Ticked: %s\n", r.Payload.(string)) - count++ - if count >=5 { - done <- true - } - }, - func(err error) { - log.Panicf("error received on channel %e", err) - }) - - // create a broker connector config, in this case, we will connect to the application fabric demo endpoint. - config := &bridge.BrokerConnectorConfig{ - Username: "guest", - Password: "guest", - ServerAddr: "appfabric.vmware.com", - WebSocketConfig: &bridge.WebSocketConfig{ - WSPath: "/fabric", - }, - UseWS: true} - - // connect to broker. - c, err := b.ConnectBroker(config) - if err != nil { - log.Panicf("unable to connect to fabric, error: %e", err) - } - - // mark our local channel as galactic and map it to our connection and the /topic/simple-stream service - // running on appfabric.vmware.com - err = cm.MarkChannelAsGalactic(channel, "/topic/simple-stream", c) - if err != nil { - log.Panicf("unable to map local channel to broker destination: %e", err) - } - - // wait for done signal - <-done - - // mark channel as local (unsubscribe from all mappings) - err = cm.MarkChannelAsLocal(channel) - if err != nil { - log.Panicf("unable to unsubscribe, error: %e", err) - } - err = c.Disconnect() - if err != nil { - log.Panicf("unable to disconnect, error: %e", err) - } + // get a pointer to the bus. + b := bus.GetBus() + + // get a pointer to the channel manager + cm := b.GetChannelManager() + + channel := "my-stream" + cm.CreateChannel(channel) + + // create done signal + var done = make(chan bool) + + // listen to stream of messages coming in on channel. + h, err := b.ListenStream(channel) + + if err != nil { + log.Panicf("unable to listen to channel stream, error: %e", err) + } + + count := 0 + + // listen for five messages and then exit, send a completed signal on channel. + h.Handle( + func(msg *model.Message) { + + // unmarshal the payload into a Response object (used by fabric services) + r := &model.Response{} + d := msg.Payload.([]byte) + json.Unmarshal(d, &r) + fmt.Printf("Stream Ticked: %s\n", r.Payload.(string)) + count++ + if count >= 5 { + done <- true + } + }, + func(err error) { + log.Panicf("error received on channel %e", err) + }) + + // create a broker connector config, in this case, we will connect to the application fabric demo endpoint. + config := &bridge.BrokerConnectorConfig{ + Username: "guest", + Password: "guest", + ServerAddr: "appfabric.vmware.com", + WebSocketConfig: &bridge.WebSocketConfig{ + WSPath: "/fabric", + }, + UseWS: true} + + // connect to broker. + c, err := b.ConnectBroker(config) + if err != nil { + log.Panicf("unable to connect to fabric, error: %e", err) + } + + // mark our local channel as galactic and map it to our connection and the /topic/simple-stream service + // running on appfabric.vmware.com + err = cm.MarkChannelAsGalactic(channel, "/topic/simple-stream", c) + if err != nil { + log.Panicf("unable to map local channel to broker destination: %e", err) + } + + // wait for done signal + <-done + + // mark channel as local (unsubscribe from all mappings) + err = cm.MarkChannelAsLocal(channel) + if err != nil { + log.Panicf("unable to unsubscribe, error: %e", err) + } + err = c.Disconnect() + if err != nil { + log.Panicf("unable to disconnect, error: %e", err) + } } diff --git a/bus/fabric_endpoint_test.go b/bus/fabric_endpoint_test.go index 894b90d..a7b4c56 100644 --- a/bus/fabric_endpoint_test.go +++ b/bus/fabric_endpoint_test.go @@ -4,393 +4,390 @@ package bus import ( - "encoding/json" - "errors" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/vmware/transport-go/model" - "github.com/vmware/transport-go/stompserver" - "sync" - "testing" + "encoding/json" + "errors" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/vmware/transport-go/model" + "github.com/vmware/transport-go/stompserver" + "sync" + "testing" ) type MockStompServerMessage struct { - Destination string `json:"destination"` - Payload []byte `json:"payload"` - conId string + Destination string `json:"destination"` + Payload []byte `json:"payload"` + conId string } type MockStompServer struct { - started bool - sentMessages []MockStompServerMessage - subscribeHandlerFunction stompserver.SubscribeHandlerFunction - connectionEventCallbacks map[stompserver.StompSessionEventType]func(event *stompserver.ConnEvent) - unsubscribeHandlerFunction stompserver.UnsubscribeHandlerFunction - applicationRequestHandlerFunction stompserver.ApplicationRequestHandlerFunction - wg *sync.WaitGroup + started bool + sentMessages []MockStompServerMessage + subscribeHandlerFunction stompserver.SubscribeHandlerFunction + connectionEventCallbacks map[stompserver.StompSessionEventType]func(event *stompserver.ConnEvent) + unsubscribeHandlerFunction stompserver.UnsubscribeHandlerFunction + applicationRequestHandlerFunction stompserver.ApplicationRequestHandlerFunction + wg *sync.WaitGroup } -func(s *MockStompServer) Start() { - s.started = true +func (s *MockStompServer) Start() { + s.started = true } -func(s *MockStompServer) Stop() { - s.started = false +func (s *MockStompServer) Stop() { + s.started = false } -func(s *MockStompServer) SendMessage(destination string, messageBody []byte) { - s.sentMessages = append(s.sentMessages, - MockStompServerMessage{Destination: destination, Payload: messageBody}) +func (s *MockStompServer) SendMessage(destination string, messageBody []byte) { + s.sentMessages = append(s.sentMessages, + MockStompServerMessage{Destination: destination, Payload: messageBody}) - if s.wg != nil { - s.wg.Done() - } + if s.wg != nil { + s.wg.Done() + } } -func(s *MockStompServer) SendMessageToClient(conId string, destination string, messageBody []byte) { - s.sentMessages = append(s.sentMessages, - MockStompServerMessage{Destination: destination, Payload: messageBody, conId: conId}) +func (s *MockStompServer) SendMessageToClient(conId string, destination string, messageBody []byte) { + s.sentMessages = append(s.sentMessages, + MockStompServerMessage{Destination: destination, Payload: messageBody, conId: conId}) - if s.wg != nil { - s.wg.Done() - } + if s.wg != nil { + s.wg.Done() + } } -func(s *MockStompServer) OnUnsubscribeEvent(callback stompserver.UnsubscribeHandlerFunction) { - s.unsubscribeHandlerFunction = callback +func (s *MockStompServer) OnUnsubscribeEvent(callback stompserver.UnsubscribeHandlerFunction) { + s.unsubscribeHandlerFunction = callback } -func(s *MockStompServer) OnApplicationRequest(callback stompserver.ApplicationRequestHandlerFunction) { - s.applicationRequestHandlerFunction = callback +func (s *MockStompServer) OnApplicationRequest(callback stompserver.ApplicationRequestHandlerFunction) { + s.applicationRequestHandlerFunction = callback } -func(s *MockStompServer) OnSubscribeEvent(callback stompserver.SubscribeHandlerFunction) { - s.subscribeHandlerFunction = callback +func (s *MockStompServer) OnSubscribeEvent(callback stompserver.SubscribeHandlerFunction) { + s.subscribeHandlerFunction = callback } func (s *MockStompServer) SetConnectionEventCallback(connEventType stompserver.StompSessionEventType, cb func(connEvent *stompserver.ConnEvent)) { - s.connectionEventCallbacks[connEventType] = cb - cb(&stompserver.ConnEvent{ConnId: "id"}) + s.connectionEventCallbacks[connEventType] = cb + cb(&stompserver.ConnEvent{ConnId: "id"}) } func newTestFabricEndpoint(bus EventBus, config EndpointConfig) (*fabricEndpoint, *MockStompServer) { - fe := newFabricEndpoint(bus, nil, config).(*fabricEndpoint) - ms := &MockStompServer{connectionEventCallbacks: make(map[stompserver.StompSessionEventType]func(event *stompserver.ConnEvent))} + fe := newFabricEndpoint(bus, nil, config).(*fabricEndpoint) + ms := &MockStompServer{connectionEventCallbacks: make(map[stompserver.StompSessionEventType]func(event *stompserver.ConnEvent))} - fe.server = ms - fe.initHandlers() + fe.server = ms + fe.initHandlers() - return fe, ms + return fe, ms } func TestFabricEndpoint_newFabricEndpoint(t *testing.T) { - fe, _ := newTestFabricEndpoint(nil, EndpointConfig{ - TopicPrefix: "/topic", - AppRequestPrefix: "/pub", - Heartbeat: 0, - }) - - assert.NotNil(t, fe) - assert.Equal(t, fe.config.TopicPrefix, "/topic/") - assert.Equal(t, fe.config.AppRequestPrefix, "/pub/") - - fe, _ = newTestFabricEndpoint(nil, EndpointConfig{ - TopicPrefix: "/topic/", - AppRequestPrefix: "", - Heartbeat: 0, - }) - - assert.Equal(t, fe.config.TopicPrefix, "/topic/") - assert.Equal(t, fe.config.AppRequestPrefix, "") + fe, _ := newTestFabricEndpoint(nil, EndpointConfig{ + TopicPrefix: "/topic", + AppRequestPrefix: "/pub", + Heartbeat: 0, + }) + + assert.NotNil(t, fe) + assert.Equal(t, fe.config.TopicPrefix, "/topic/") + assert.Equal(t, fe.config.AppRequestPrefix, "/pub/") + + fe, _ = newTestFabricEndpoint(nil, EndpointConfig{ + TopicPrefix: "/topic/", + AppRequestPrefix: "", + Heartbeat: 0, + }) + + assert.Equal(t, fe.config.TopicPrefix, "/topic/") + assert.Equal(t, fe.config.AppRequestPrefix, "") } func TestFabricEndpoint_StartAndStop(t *testing.T) { - fe, mockServer := newTestFabricEndpoint(nil, EndpointConfig{}) - assert.Equal(t, mockServer.started, false) - fe.Start() - assert.Equal(t, mockServer.started, true) - fe.Stop() - assert.Equal(t, mockServer.started, false) + fe, mockServer := newTestFabricEndpoint(nil, EndpointConfig{}) + assert.Equal(t, mockServer.started, false) + fe.Start() + assert.Equal(t, mockServer.started, true) + fe.Stop() + assert.Equal(t, mockServer.started, false) } func TestFabricEndpoint_SubscribeEvent(t *testing.T) { - bus := newTestEventBus() - bus.GetChannelManager().CreateChannel(STOMP_SESSION_NOTIFY_CHANNEL) // used for internal channel protection test - fe, mockServer := newTestFabricEndpoint(bus, - EndpointConfig{TopicPrefix: "/topic", UserQueuePrefix:"/user/queue"}) - - bus.GetChannelManager().CreateChannel("test-service") - - monitorWg := sync.WaitGroup{} - var monitorEvents []*MonitorEvent - bus.AddMonitorEventListener(func(monitorEvt *MonitorEvent) { - monitorEvents = append(monitorEvents, monitorEvt) - monitorWg.Done() - }, FabricEndpointSubscribeEvt) - - // subscribe to invalid topic - mockServer.subscribeHandlerFunction("con1", "sub1", "/topic2/test-service", nil) - assert.Equal(t, len(fe.chanMappings), 0) - - bus.SendResponseMessage("test-service", "test-message", nil) - assert.Equal(t, len(mockServer.sentMessages), 0) - - // subscribe to valid channel - monitorWg.Add(1) - mockServer.subscribeHandlerFunction("con1", "sub1", "/topic/test-service", nil) - monitorWg.Wait() - assert.Equal(t, len(monitorEvents), 1) - assert.Equal(t, monitorEvents[0].EventType, FabricEndpointSubscribeEvt) - assert.Equal(t, monitorEvents[0].EntityName, "test-service") - - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["test-service"].subs), 1) - assert.Equal(t, fe.chanMappings["test-service"].subs["con1#sub1"], true) - - // subscribe again to the same channel - monitorWg.Add(1) - mockServer.subscribeHandlerFunction("con1", "sub2", "/topic/test-service", nil) - monitorWg.Wait() - - assert.Equal(t, len(monitorEvents), 2) - assert.Equal(t, monitorEvents[1].EventType, FabricEndpointSubscribeEvt) - assert.Equal(t, monitorEvents[1].EntityName, "test-service") - - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) - assert.Equal(t, fe.chanMappings["test-service"].subs["con1#sub2"], true) - - // subscribe to queue channel - monitorWg.Add(1) - mockServer.subscribeHandlerFunction("con1", "sub3", "/user/queue/test-service", nil) - monitorWg.Wait() - assert.Equal(t, len(monitorEvents), 3) - assert.Equal(t, monitorEvents[2].EventType, FabricEndpointSubscribeEvt) - assert.Equal(t, monitorEvents[2].EntityName, "test-service") - - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["test-service"].subs), 3) - assert.Equal(t, fe.chanMappings["test-service"].subs["con1#sub3"], true) - - // attempt to subscribe to a protected destination - mockServer.subscribeHandlerFunction("con1", "sub4", "/topic/" + STOMP_SESSION_NOTIFY_CHANNEL, nil) - _, chanMapCreated := fe.chanMappings[STOMP_SESSION_NOTIFY_CHANNEL] - assert.False(t, chanMapCreated) - - mockServer.wg = &sync.WaitGroup{} - mockServer.wg.Add(1) - - bus.SendResponseMessage("test-service", "test-message", nil) - - mockServer.wg.Wait() - - mockServer.wg.Add(1) - bus.SendResponseMessage("test-service", []byte{1,2,3}, nil) - mockServer.wg.Wait() - - mockServer.wg.Add(1) - msg := MockStompServerMessage{Destination: "test", Payload: []byte("test-message")} - bus.SendResponseMessage("test-service", msg, nil) - mockServer.wg.Wait() - - mockServer.wg.Add(1) - bus.SendErrorMessage("test-service", errors.New("test-error"), nil) - mockServer.wg.Wait() - - assert.Equal(t, len(mockServer.sentMessages), 4) - assert.Equal(t, mockServer.sentMessages[0].Destination, "/topic/test-service") - assert.Equal(t, string(mockServer.sentMessages[0].Payload), "test-message") - assert.Equal(t, mockServer.sentMessages[1].Payload, []byte{1,2,3}) - - var sentMsg MockStompServerMessage - json.Unmarshal(mockServer.sentMessages[2].Payload, &sentMsg) - assert.Equal(t, msg, sentMsg ) - - assert.Equal(t, string(mockServer.sentMessages[3].Payload), "test-error") - - mockServer.wg.Add(1) - bus.SendResponseMessage("test-service", model.Response{ - BrokerDestination: &model.BrokerDestinationConfig{ - Destination: "/user/queue/test-service", - ConnectionId: "con1", - }, - Payload: "test-private-message", - }, nil) - - mockServer.wg.Wait() - - assert.Equal(t, len(mockServer.sentMessages), 5) - assert.Equal(t, mockServer.sentMessages[4].Destination, "/user/queue/test-service") - var sentResponse model.Response - json.Unmarshal(mockServer.sentMessages[4].Payload, &sentResponse) - assert.Equal(t, sentResponse.Payload, "test-private-message") - - mockServer.wg.Add(1) - bus.SendResponseMessage("test-service", &model.Response{ - BrokerDestination: &model.BrokerDestinationConfig{ - Destination: "/user/queue/test-service", - ConnectionId: "con1", - }, - Payload: "test-private-message-ptr", - }, nil) - - mockServer.wg.Wait() - - assert.Equal(t, len(mockServer.sentMessages), 6) - assert.Equal(t, mockServer.sentMessages[5].Destination, "/user/queue/test-service") - json.Unmarshal(mockServer.sentMessages[5].Payload, &sentResponse) - assert.Equal(t, sentResponse.Payload, "test-private-message-ptr") + bus := newTestEventBus() + bus.GetChannelManager().CreateChannel(STOMP_SESSION_NOTIFY_CHANNEL) // used for internal channel protection test + fe, mockServer := newTestFabricEndpoint(bus, + EndpointConfig{TopicPrefix: "/topic", UserQueuePrefix: "/user/queue"}) + + bus.GetChannelManager().CreateChannel("test-service") + + monitorWg := sync.WaitGroup{} + var monitorEvents []*MonitorEvent + bus.AddMonitorEventListener(func(monitorEvt *MonitorEvent) { + monitorEvents = append(monitorEvents, monitorEvt) + monitorWg.Done() + }, FabricEndpointSubscribeEvt) + + // subscribe to invalid topic + mockServer.subscribeHandlerFunction("con1", "sub1", "/topic2/test-service", nil) + assert.Equal(t, len(fe.chanMappings), 0) + + bus.SendResponseMessage("test-service", "test-message", nil) + assert.Equal(t, len(mockServer.sentMessages), 0) + + // subscribe to valid channel + monitorWg.Add(1) + mockServer.subscribeHandlerFunction("con1", "sub1", "/topic/test-service", nil) + monitorWg.Wait() + assert.Equal(t, len(monitorEvents), 1) + assert.Equal(t, monitorEvents[0].EventType, FabricEndpointSubscribeEvt) + assert.Equal(t, monitorEvents[0].EntityName, "test-service") + + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["test-service"].subs), 1) + assert.Equal(t, fe.chanMappings["test-service"].subs["con1#sub1"], true) + + // subscribe again to the same channel + monitorWg.Add(1) + mockServer.subscribeHandlerFunction("con1", "sub2", "/topic/test-service", nil) + monitorWg.Wait() + + assert.Equal(t, len(monitorEvents), 2) + assert.Equal(t, monitorEvents[1].EventType, FabricEndpointSubscribeEvt) + assert.Equal(t, monitorEvents[1].EntityName, "test-service") + + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) + assert.Equal(t, fe.chanMappings["test-service"].subs["con1#sub2"], true) + + // subscribe to queue channel + monitorWg.Add(1) + mockServer.subscribeHandlerFunction("con1", "sub3", "/user/queue/test-service", nil) + monitorWg.Wait() + assert.Equal(t, len(monitorEvents), 3) + assert.Equal(t, monitorEvents[2].EventType, FabricEndpointSubscribeEvt) + assert.Equal(t, monitorEvents[2].EntityName, "test-service") + + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["test-service"].subs), 3) + assert.Equal(t, fe.chanMappings["test-service"].subs["con1#sub3"], true) + + // attempt to subscribe to a protected destination + mockServer.subscribeHandlerFunction("con1", "sub4", "/topic/"+STOMP_SESSION_NOTIFY_CHANNEL, nil) + _, chanMapCreated := fe.chanMappings[STOMP_SESSION_NOTIFY_CHANNEL] + assert.False(t, chanMapCreated) + + mockServer.wg = &sync.WaitGroup{} + mockServer.wg.Add(1) + + bus.SendResponseMessage("test-service", "test-message", nil) + + mockServer.wg.Wait() + + mockServer.wg.Add(1) + bus.SendResponseMessage("test-service", []byte{1, 2, 3}, nil) + mockServer.wg.Wait() + + mockServer.wg.Add(1) + msg := MockStompServerMessage{Destination: "test", Payload: []byte("test-message")} + bus.SendResponseMessage("test-service", msg, nil) + mockServer.wg.Wait() + + mockServer.wg.Add(1) + bus.SendErrorMessage("test-service", errors.New("test-error"), nil) + mockServer.wg.Wait() + + assert.Equal(t, len(mockServer.sentMessages), 4) + assert.Equal(t, mockServer.sentMessages[0].Destination, "/topic/test-service") + assert.Equal(t, string(mockServer.sentMessages[0].Payload), "test-message") + assert.Equal(t, mockServer.sentMessages[1].Payload, []byte{1, 2, 3}) + + var sentMsg MockStompServerMessage + json.Unmarshal(mockServer.sentMessages[2].Payload, &sentMsg) + assert.Equal(t, msg, sentMsg) + + assert.Equal(t, string(mockServer.sentMessages[3].Payload), "test-error") + + mockServer.wg.Add(1) + bus.SendResponseMessage("test-service", model.Response{ + BrokerDestination: &model.BrokerDestinationConfig{ + Destination: "/user/queue/test-service", + ConnectionId: "con1", + }, + Payload: "test-private-message", + }, nil) + + mockServer.wg.Wait() + + assert.Equal(t, len(mockServer.sentMessages), 5) + assert.Equal(t, mockServer.sentMessages[4].Destination, "/user/queue/test-service") + var sentResponse model.Response + json.Unmarshal(mockServer.sentMessages[4].Payload, &sentResponse) + assert.Equal(t, sentResponse.Payload, "test-private-message") + + mockServer.wg.Add(1) + bus.SendResponseMessage("test-service", &model.Response{ + BrokerDestination: &model.BrokerDestinationConfig{ + Destination: "/user/queue/test-service", + ConnectionId: "con1", + }, + Payload: "test-private-message-ptr", + }, nil) + + mockServer.wg.Wait() + + assert.Equal(t, len(mockServer.sentMessages), 6) + assert.Equal(t, mockServer.sentMessages[5].Destination, "/user/queue/test-service") + json.Unmarshal(mockServer.sentMessages[5].Payload, &sentResponse) + assert.Equal(t, sentResponse.Payload, "test-private-message-ptr") } func TestFabricEndpoint_UnsubscribeEvent(t *testing.T) { - bus := newTestEventBus() - fe, mockServer := newTestFabricEndpoint(bus, EndpointConfig{TopicPrefix: "/topic"}) - - bus.GetChannelManager().CreateChannel("test-service") - - monitorWg := sync.WaitGroup{} - var monitorEvents []*MonitorEvent - bus.AddMonitorEventListener(func(monitorEvt *MonitorEvent) { - monitorEvents = append(monitorEvents, monitorEvt) - monitorWg.Done() - }, FabricEndpointUnsubscribeEvt) - - // subscribe to valid channel - mockServer.subscribeHandlerFunction("con1", "sub1", "/topic/test-service", nil) - mockServer.subscribeHandlerFunction("con1", "sub2", "/topic/test-service", nil) - - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) - - mockServer.wg = &sync.WaitGroup{} - mockServer.wg.Add(1) - bus.SendResponseMessage("test-service", "test-message", nil) - mockServer.wg.Wait() - assert.Equal(t, len(mockServer.sentMessages), 1) - - - mockServer.unsubscribeHandlerFunction("con1", "sub2", "/invalid-topic/test-service") - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) - - mockServer.unsubscribeHandlerFunction("invalid-con1", "sub2", "/topic/test-service") - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) - - monitorWg.Add(1) - mockServer.unsubscribeHandlerFunction("con1", "sub2", "/topic/test-service") - monitorWg.Wait() - - assert.Equal(t, len(monitorEvents), 1) - assert.Equal(t, monitorEvents[0].EventType, FabricEndpointUnsubscribeEvt) - assert.Equal(t, monitorEvents[0].EntityName, "test-service") - - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["test-service"].subs), 1) - - mockServer.wg = &sync.WaitGroup{} - mockServer.wg.Add(1) - bus.SendResponseMessage("test-service", "test-message", nil) - mockServer.wg.Wait() - assert.Equal(t, len(mockServer.sentMessages), 2) - - monitorWg.Add(1) - mockServer.unsubscribeHandlerFunction("con1", "sub1", "/topic/test-service") - monitorWg.Wait() - - assert.Equal(t, len(monitorEvents), 2) - assert.Equal(t, monitorEvents[1].EventType, FabricEndpointUnsubscribeEvt) - assert.Equal(t, monitorEvents[1].EntityName, "test-service") - - - assert.Equal(t, len(fe.chanMappings), 0) - bus.SendResponseMessage("test-service", "test-message", nil) - - // subscribe to non-existing channel - mockServer.subscribeHandlerFunction("con3", "sub1", "/topic/non-existing-channel", nil) - assert.Equal(t, len(fe.chanMappings), 1) - assert.Equal(t, len(fe.chanMappings["non-existing-channel"].subs), 1) - assert.Equal(t, fe.chanMappings["non-existing-channel"].autoCreated, true) - assert.True(t, bus.GetChannelManager().CheckChannelExists("non-existing-channel")) - - monitorWg.Add(1) - mockServer.unsubscribeHandlerFunction("con3", "sub1", "/topic/non-existing-channel") - monitorWg.Wait() - - assert.Equal(t, len(monitorEvents), 3) - assert.Equal(t, monitorEvents[2].EventType, FabricEndpointUnsubscribeEvt) - assert.Equal(t, monitorEvents[2].EntityName, "non-existing-channel") - - assert.Equal(t, len(fe.chanMappings), 0) - assert.False(t, bus.GetChannelManager().CheckChannelExists("non-existing-channel")) + bus := newTestEventBus() + fe, mockServer := newTestFabricEndpoint(bus, EndpointConfig{TopicPrefix: "/topic"}) + + bus.GetChannelManager().CreateChannel("test-service") + + monitorWg := sync.WaitGroup{} + var monitorEvents []*MonitorEvent + bus.AddMonitorEventListener(func(monitorEvt *MonitorEvent) { + monitorEvents = append(monitorEvents, monitorEvt) + monitorWg.Done() + }, FabricEndpointUnsubscribeEvt) + + // subscribe to valid channel + mockServer.subscribeHandlerFunction("con1", "sub1", "/topic/test-service", nil) + mockServer.subscribeHandlerFunction("con1", "sub2", "/topic/test-service", nil) + + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) + + mockServer.wg = &sync.WaitGroup{} + mockServer.wg.Add(1) + bus.SendResponseMessage("test-service", "test-message", nil) + mockServer.wg.Wait() + assert.Equal(t, len(mockServer.sentMessages), 1) + + mockServer.unsubscribeHandlerFunction("con1", "sub2", "/invalid-topic/test-service") + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) + + mockServer.unsubscribeHandlerFunction("invalid-con1", "sub2", "/topic/test-service") + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["test-service"].subs), 2) + + monitorWg.Add(1) + mockServer.unsubscribeHandlerFunction("con1", "sub2", "/topic/test-service") + monitorWg.Wait() + + assert.Equal(t, len(monitorEvents), 1) + assert.Equal(t, monitorEvents[0].EventType, FabricEndpointUnsubscribeEvt) + assert.Equal(t, monitorEvents[0].EntityName, "test-service") + + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["test-service"].subs), 1) + + mockServer.wg = &sync.WaitGroup{} + mockServer.wg.Add(1) + bus.SendResponseMessage("test-service", "test-message", nil) + mockServer.wg.Wait() + assert.Equal(t, len(mockServer.sentMessages), 2) + + monitorWg.Add(1) + mockServer.unsubscribeHandlerFunction("con1", "sub1", "/topic/test-service") + monitorWg.Wait() + + assert.Equal(t, len(monitorEvents), 2) + assert.Equal(t, monitorEvents[1].EventType, FabricEndpointUnsubscribeEvt) + assert.Equal(t, monitorEvents[1].EntityName, "test-service") + + assert.Equal(t, len(fe.chanMappings), 0) + bus.SendResponseMessage("test-service", "test-message", nil) + + // subscribe to non-existing channel + mockServer.subscribeHandlerFunction("con3", "sub1", "/topic/non-existing-channel", nil) + assert.Equal(t, len(fe.chanMappings), 1) + assert.Equal(t, len(fe.chanMappings["non-existing-channel"].subs), 1) + assert.Equal(t, fe.chanMappings["non-existing-channel"].autoCreated, true) + assert.True(t, bus.GetChannelManager().CheckChannelExists("non-existing-channel")) + + monitorWg.Add(1) + mockServer.unsubscribeHandlerFunction("con3", "sub1", "/topic/non-existing-channel") + monitorWg.Wait() + + assert.Equal(t, len(monitorEvents), 3) + assert.Equal(t, monitorEvents[2].EventType, FabricEndpointUnsubscribeEvt) + assert.Equal(t, monitorEvents[2].EntityName, "non-existing-channel") + + assert.Equal(t, len(fe.chanMappings), 0) + assert.False(t, bus.GetChannelManager().CheckChannelExists("non-existing-channel")) } func TestFabricEndpoint_BridgeMessage(t *testing.T) { - bus := newTestEventBus() - _, mockServer := newTestFabricEndpoint(bus, EndpointConfig{TopicPrefix: "/topic", AppRequestPrefix:"/pub", - AppRequestQueuePrefix: "/pub/queue", UserQueuePrefix:"/user/queue" }) - - bus.GetChannelManager().CreateChannel("request-channel") - mh, _ := bus.ListenRequestStream("request-channel") - assert.NotNil(t, mh) + bus := newTestEventBus() + _, mockServer := newTestFabricEndpoint(bus, EndpointConfig{TopicPrefix: "/topic", AppRequestPrefix: "/pub", + AppRequestQueuePrefix: "/pub/queue", UserQueuePrefix: "/user/queue"}) - wg := sync.WaitGroup{} + bus.GetChannelManager().CreateChannel("request-channel") + mh, _ := bus.ListenRequestStream("request-channel") + assert.NotNil(t, mh) - var messages []*model.Message + wg := sync.WaitGroup{} - mh.Handle(func(message *model.Message) { - messages = append(messages, message) - wg.Done() - }, func(e error) { - assert.Fail(t, "unexpected error") - }) + var messages []*model.Message - id1 := uuid.New() - req1, _ := json.Marshal(model.Request{ - Request: "test-request", - Payload: "test-rq", - Id: &id1, - }) + mh.Handle(func(message *model.Message) { + messages = append(messages, message) + wg.Done() + }, func(e error) { + assert.Fail(t, "unexpected error") + }) - wg.Add(1) + id1 := uuid.New() + req1, _ := json.Marshal(model.Request{ + Request: "test-request", + Payload: "test-rq", + Id: &id1, + }) - mockServer.applicationRequestHandlerFunction("/pub/request-channel", req1, "con1") + wg.Add(1) - mockServer.applicationRequestHandlerFunction("/pub2/request-channel", req1, "con1") - mockServer.applicationRequestHandlerFunction("/pub/request-channel-2", req1, "con1") + mockServer.applicationRequestHandlerFunction("/pub/request-channel", req1, "con1") - mockServer.applicationRequestHandlerFunction("/pub/request-channel", []byte("invalid-request-json"), "con1") + mockServer.applicationRequestHandlerFunction("/pub2/request-channel", req1, "con1") + mockServer.applicationRequestHandlerFunction("/pub/request-channel-2", req1, "con1") - id2 := uuid.New() - req2, _ := json.Marshal(model.Request{ - Request: "test-request2", - Payload: "test-rq2", - Id: &id2, - }) + mockServer.applicationRequestHandlerFunction("/pub/request-channel", []byte("invalid-request-json"), "con1") - wg.Wait() + id2 := uuid.New() + req2, _ := json.Marshal(model.Request{ + Request: "test-request2", + Payload: "test-rq2", + Id: &id2, + }) - wg.Add(1) - mockServer.applicationRequestHandlerFunction("/pub/queue/request-channel", req2, "con2") - wg.Wait() + wg.Wait() + wg.Add(1) + mockServer.applicationRequestHandlerFunction("/pub/queue/request-channel", req2, "con2") + wg.Wait() - assert.Equal(t, len(messages), 2) + assert.Equal(t, len(messages), 2) - receivedReq := messages[0].Payload.(*model.Request) + receivedReq := messages[0].Payload.(*model.Request) - assert.Equal(t, receivedReq.Request, "test-request") - assert.Equal(t, receivedReq.Payload, "test-rq") - assert.Equal(t, *receivedReq.Id, id1) - assert.Nil(t, receivedReq.BrokerDestination) + assert.Equal(t, receivedReq.Request, "test-request") + assert.Equal(t, receivedReq.Payload, "test-rq") + assert.Equal(t, *receivedReq.Id, id1) + assert.Nil(t, receivedReq.BrokerDestination) - receivedReq2 := messages[1].Payload.(*model.Request) + receivedReq2 := messages[1].Payload.(*model.Request) - assert.Equal(t, receivedReq2.Request, "test-request2") - assert.Equal(t, receivedReq2.Payload, "test-rq2") - assert.Equal(t, *receivedReq2.Id, id2) - assert.Equal(t, receivedReq2.BrokerDestination.ConnectionId, "con2") - assert.Equal(t, receivedReq2.BrokerDestination.Destination, "/user/queue/request-channel") + assert.Equal(t, receivedReq2.Request, "test-request2") + assert.Equal(t, receivedReq2.Payload, "test-rq2") + assert.Equal(t, *receivedReq2.Id, id2) + assert.Equal(t, receivedReq2.BrokerDestination.ConnectionId, "con2") + assert.Equal(t, receivedReq2.BrokerDestination.Destination, "/user/queue/request-channel") } diff --git a/bus/message_handler.go b/bus/message_handler.go index 702aac6..9945ada 100644 --- a/bus/message_handler.go +++ b/bus/message_handler.go @@ -4,10 +4,10 @@ package bus import ( - "fmt" - "github.com/google/uuid" - "github.com/vmware/transport-go/model" - "sync" + "fmt" + "github.com/google/uuid" + "github.com/vmware/transport-go/model" + "sync" ) // Signature used for all functions used on bus stream APIs to Handle messages. @@ -20,59 +20,59 @@ type MessageErrorFunction func(error) // It also provides a Handle method that accepts a success and error function as handlers. // The Fire method will fire the message queued when using RequestOnce or RequestStream type MessageHandler interface { - GetId() *uuid.UUID - GetDestinationId() *uuid.UUID - Handle(successHandler MessageHandlerFunction, errorHandler MessageErrorFunction) - Fire() error - Close() + GetId() *uuid.UUID + GetDestinationId() *uuid.UUID + Handle(successHandler MessageHandlerFunction, errorHandler MessageErrorFunction) + Fire() error + Close() } type messageHandler struct { - id *uuid.UUID - destination *uuid.UUID - eventCount int64 - closed bool - channel *Channel - requestMessage *model.Message - runCount int64 - ignoreId bool - wrapperFunction MessageHandlerFunction - successHandler MessageHandlerFunction - errorHandler MessageErrorFunction - subscriptionId *uuid.UUID - invokeOnce *sync.Once - channelManager ChannelManager + id *uuid.UUID + destination *uuid.UUID + eventCount int64 + closed bool + channel *Channel + requestMessage *model.Message + runCount int64 + ignoreId bool + wrapperFunction MessageHandlerFunction + successHandler MessageHandlerFunction + errorHandler MessageErrorFunction + subscriptionId *uuid.UUID + invokeOnce *sync.Once + channelManager ChannelManager } func (msgHandler *messageHandler) Handle(successHandler MessageHandlerFunction, errorHandler MessageErrorFunction) { - msgHandler.successHandler = successHandler - msgHandler.errorHandler = errorHandler + msgHandler.successHandler = successHandler + msgHandler.errorHandler = errorHandler - msgHandler.subscriptionId, _ = msgHandler.channelManager.SubscribeChannelHandler( - msgHandler.channel.Name, msgHandler.wrapperFunction, false) + msgHandler.subscriptionId, _ = msgHandler.channelManager.SubscribeChannelHandler( + msgHandler.channel.Name, msgHandler.wrapperFunction, false) } -func (msgHandler *messageHandler) Close() { - if msgHandler.subscriptionId != nil { - msgHandler.channelManager.UnsubscribeChannelHandler( - msgHandler.channel.Name, msgHandler.subscriptionId) - } +func (msgHandler *messageHandler) Close() { + if msgHandler.subscriptionId != nil { + msgHandler.channelManager.UnsubscribeChannelHandler( + msgHandler.channel.Name, msgHandler.subscriptionId) + } } func (msgHandler *messageHandler) GetId() *uuid.UUID { - return msgHandler.id + return msgHandler.id } func (msgHandler *messageHandler) GetDestinationId() *uuid.UUID { - return msgHandler.destination + return msgHandler.destination } func (msgHandler *messageHandler) Fire() error { - if msgHandler.requestMessage != nil { - sendMessageToChannel(msgHandler.channel, msgHandler.requestMessage) - msgHandler.channel.wg.Wait() - return nil - } else { - return fmt.Errorf("nothing to fire, request is empty") - } + if msgHandler.requestMessage != nil { + sendMessageToChannel(msgHandler.channel, msgHandler.requestMessage) + msgHandler.channel.wg.Wait() + return nil + } else { + return fmt.Errorf("nothing to fire, request is empty") + } } diff --git a/bus/message_test.go b/bus/message_test.go index bade993..e80ae1e 100644 --- a/bus/message_test.go +++ b/bus/message_test.go @@ -4,21 +4,20 @@ package bus import ( - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/vmware/transport-go/model" - "testing" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/vmware/transport-go/model" + "testing" ) func TestMessageModel(t *testing.T) { - id := uuid.New() - var message = &model.Message{ - Id: &id, - Payload: "A new message", - Channel: "123", - Direction: model.RequestDir} - assert.Equal(t, "A new message", message.Payload) - assert.Equal(t, model.RequestDir, message.Direction, ) - assert.Equal(t, message.Channel, "123") + id := uuid.New() + var message = &model.Message{ + Id: &id, + Payload: "A new message", + Channel: "123", + Direction: model.RequestDir} + assert.Equal(t, "A new message", message.Payload) + assert.Equal(t, model.RequestDir, message.Direction) + assert.Equal(t, message.Channel, "123") } - diff --git a/bus/monitor_event.go b/bus/monitor_event.go index c299bb3..9004c88 100644 --- a/bus/monitor_event.go +++ b/bus/monitor_event.go @@ -4,32 +4,33 @@ package bus type MonitorEventType int32 + const ( - ChannelCreatedEvt MonitorEventType = iota - ChannelDestroyedEvt - ChannelSubscriberJoinedEvt - ChannelSubscriberLeftEvt - StoreCreatedEvt - StoreDestroyedEvt - StoreInitializedEvt - BrokerSubscribedEvt - BrokerUnsubscribedEvt - FabricEndpointSubscribeEvt - FabricEndpointUnsubscribeEvt + ChannelCreatedEvt MonitorEventType = iota + ChannelDestroyedEvt + ChannelSubscriberJoinedEvt + ChannelSubscriberLeftEvt + StoreCreatedEvt + StoreDestroyedEvt + StoreInitializedEvt + BrokerSubscribedEvt + BrokerUnsubscribedEvt + FabricEndpointSubscribeEvt + FabricEndpointUnsubscribeEvt ) type MonitorEventHandler func(event *MonitorEvent) type MonitorEvent struct { - // Type of the event - EventType MonitorEventType - // The name of the channel or the store related to this event - EntityName string - // Optional event data - Data interface{} + // Type of the event + EventType MonitorEventType + // The name of the channel or the store related to this event + EntityName string + // Optional event data + Data interface{} } // Create a new monitor event -func NewMonitorEvent(evtType MonitorEventType, entityName string, data interface{} ) *MonitorEvent { - return &MonitorEvent{EventType: evtType, Data: data, EntityName: entityName} +func NewMonitorEvent(evtType MonitorEventType, entityName string, data interface{}) *MonitorEvent { + return &MonitorEvent{EventType: evtType, Data: data, EntityName: entityName} } diff --git a/bus/mutation_store_stream.go b/bus/mutation_store_stream.go index e1cc289..643c797 100644 --- a/bus/mutation_store_stream.go +++ b/bus/mutation_store_stream.go @@ -4,97 +4,97 @@ package bus import ( - "sync" - "fmt" + "fmt" + "sync" ) type MutationRequest struct { - Request interface{} - RequestType interface{} - SuccessHandler func(interface{}) - ErrorHandler func(interface{}) + Request interface{} + RequestType interface{} + SuccessHandler func(interface{}) + ErrorHandler func(interface{}) } type MutationRequestHandlerFunction func(mutationReq *MutationRequest) // Interface for subscribing for mutation requests type MutationStoreStream interface { - // Subscribe to the mutation requests stream. - Subscribe(handler MutationRequestHandlerFunction) error - // Unsubscribe from the stream. - Unsubscribe() error + // Subscribe to the mutation requests stream. + Subscribe(handler MutationRequestHandlerFunction) error + // Unsubscribe from the stream. + Unsubscribe() error } type mutationStreamFilter struct { - requestTypes []interface{} + requestTypes []interface{} } func (f *mutationStreamFilter) match(mutationReq *MutationRequest) bool { - if len(f.requestTypes) == 0 { - return true - } + if len(f.requestTypes) == 0 { + return true + } - for _, s := range f.requestTypes { - if mutationReq.RequestType == s { - return true - } - } + for _, s := range f.requestTypes { + if mutationReq.RequestType == s { + return true + } + } - return false + return false } type mutationStoreStream struct { - handler MutationRequestHandlerFunction - lock sync.RWMutex - store *busStore - filter *mutationStreamFilter + handler MutationRequestHandlerFunction + lock sync.RWMutex + store *busStore + filter *mutationStreamFilter } func newMutationStoreStream(store *busStore, filter *mutationStreamFilter) *mutationStoreStream { - stream := new(mutationStoreStream) - stream.store = store - stream.filter = filter - return stream + stream := new(mutationStoreStream) + stream.store = store + stream.filter = filter + return stream } func (ms *mutationStoreStream) Subscribe(handler MutationRequestHandlerFunction) error { - if handler == nil { - return fmt.Errorf("invalid MutationRequestHandlerFunction") - } - - ms.lock.Lock() - if ms.handler != nil { - ms.lock.Unlock() - return fmt.Errorf("stream already subscribed") - } - ms.handler = handler - ms.lock.Unlock() - - ms.store.onMutationStreamSubscribe(ms) - return nil + if handler == nil { + return fmt.Errorf("invalid MutationRequestHandlerFunction") + } + + ms.lock.Lock() + if ms.handler != nil { + ms.lock.Unlock() + return fmt.Errorf("stream already subscribed") + } + ms.handler = handler + ms.lock.Unlock() + + ms.store.onMutationStreamSubscribe(ms) + return nil } func (ms *mutationStoreStream) Unsubscribe() error { - ms.lock.Lock() - if ms.handler == nil { - ms.lock.Unlock() - return fmt.Errorf("stream not subscribed") - } - ms.handler = nil - ms.lock.Unlock() - - ms.store.onMutationStreamUnsubscribe(ms) - return nil + ms.lock.Lock() + if ms.handler == nil { + ms.lock.Unlock() + return fmt.Errorf("stream not subscribed") + } + ms.handler = nil + ms.lock.Unlock() + + ms.store.onMutationStreamUnsubscribe(ms) + return nil } func (ms *mutationStoreStream) onMutationRequest(mutationReq *MutationRequest) { - if !ms.filter.match(mutationReq) { - return - } - - ms.lock.RLock() - defer ms.lock.RUnlock() - if ms.handler != nil { - go ms.handler(mutationReq) - } + if !ms.filter.match(mutationReq) { + return + } + + ms.lock.RLock() + defer ms.lock.RUnlock() + if ms.handler != nil { + go ms.handler(mutationReq) + } } diff --git a/bus/store.go b/bus/store.go index 74b26c7..52ed1d5 100644 --- a/bus/store.go +++ b/bus/store.go @@ -4,551 +4,551 @@ package bus import ( - "encoding/json" - "fmt" - "github.com/google/uuid" - "github.com/vmware/transport-go/log" - "github.com/vmware/transport-go/model" - "reflect" - "sync" + "encoding/json" + "fmt" + "github.com/google/uuid" + "github.com/vmware/transport-go/log" + "github.com/vmware/transport-go/model" + "reflect" + "sync" ) // Describes a single store item change type StoreChange struct { - Id string // the id of the updated item - Value interface{} // the updated value of the item - State interface{} // state associated with this change - IsDeleteChange bool // true if the item was removed from the store - StoreVersion int64 // the store's version when this change was made + Id string // the id of the updated item + Value interface{} // the updated value of the item + State interface{} // state associated with this change + IsDeleteChange bool // true if the item was removed from the store + StoreVersion int64 // the store's version when this change was made } // BusStore is a stateful in memory cache for objects. All state changes (any time the cache is modified) // will broadcast that updated object to any subscribers of the BusStore for those specific objects // or all objects of a certain type and state changes. type BusStore interface { - // Get the name (the id) of the store. - GetName() string - // Add new or updates existing item in the store. - Put(id string, value interface{}, state interface{}) - // Returns an item from the store and a boolean flag - // indicating whether the item exists - Get(id string) (interface{}, bool) - // Shorten version of the Get() method, returns only the item value. - GetValue(id string) interface{} - // Remove an item from the store. Returns true if the remove operation was successful. - Remove(id string, state interface{}) bool - // Return a slice containing all store items. - AllValues() []interface{} - // Return a map with all items from the store. - AllValuesAsMap() map[string]interface{} - // Return a map with all items from the store with the current store version. - AllValuesAndVersion() (map[string]interface{}, int64) - // Subscribe to state changes for a specific object. - OnChange(id string, state ...interface{}) StoreStream - // Subscribe to state changes for all objects - OnAllChanges(state ...interface{}) StoreStream - // Notify when the store has been initialize (via populate() or initialize() - WhenReady(readyFunction func()) - // Populate the store with a map of items and their ID's. - Populate(items map[string]interface{}) error - // Mark the store as initialized and notify all watchers. - Initialize() - // Subscribe to mutation requests made via mutate() method. - OnMutationRequest(mutationType ...interface{}) MutationStoreStream - // Send a mutation request to any subscribers handling mutations. - Mutate(request interface{}, requestType interface{}, - successHandler func(interface{}), errorHandler func(interface{})) - // Removes all items from the store and change its state to uninitialized". - Reset() - // Returns true if this is galactic store. - IsGalactic() bool - // Get the item type if such is specified during the creation of the - // store - GetItemType() reflect.Type + // Get the name (the id) of the store. + GetName() string + // Add new or updates existing item in the store. + Put(id string, value interface{}, state interface{}) + // Returns an item from the store and a boolean flag + // indicating whether the item exists + Get(id string) (interface{}, bool) + // Shorten version of the Get() method, returns only the item value. + GetValue(id string) interface{} + // Remove an item from the store. Returns true if the remove operation was successful. + Remove(id string, state interface{}) bool + // Return a slice containing all store items. + AllValues() []interface{} + // Return a map with all items from the store. + AllValuesAsMap() map[string]interface{} + // Return a map with all items from the store with the current store version. + AllValuesAndVersion() (map[string]interface{}, int64) + // Subscribe to state changes for a specific object. + OnChange(id string, state ...interface{}) StoreStream + // Subscribe to state changes for all objects + OnAllChanges(state ...interface{}) StoreStream + // Notify when the store has been initialize (via populate() or initialize() + WhenReady(readyFunction func()) + // Populate the store with a map of items and their ID's. + Populate(items map[string]interface{}) error + // Mark the store as initialized and notify all watchers. + Initialize() + // Subscribe to mutation requests made via mutate() method. + OnMutationRequest(mutationType ...interface{}) MutationStoreStream + // Send a mutation request to any subscribers handling mutations. + Mutate(request interface{}, requestType interface{}, + successHandler func(interface{}), errorHandler func(interface{})) + // Removes all items from the store and change its state to uninitialized". + Reset() + // Returns true if this is galactic store. + IsGalactic() bool + // Get the item type if such is specified during the creation of the + // store + GetItemType() reflect.Type } // Internal BusStore implementation type busStore struct { - name string - itemsLock sync.RWMutex - items map[string]interface{} - storeVersion int64 - storeStreamsLock sync.RWMutex - storeStreams []*storeStream - mutationStreamsLock sync.RWMutex - mutationStreams []*mutationStoreStream - initializer sync.Once - readyC chan struct{} - isGalactic bool - galacticConf *galacticStoreConfig - bus EventBus - itemType reflect.Type - storeSynHandler MessageHandler + name string + itemsLock sync.RWMutex + items map[string]interface{} + storeVersion int64 + storeStreamsLock sync.RWMutex + storeStreams []*storeStream + mutationStreamsLock sync.RWMutex + mutationStreams []*mutationStoreStream + initializer sync.Once + readyC chan struct{} + isGalactic bool + galacticConf *galacticStoreConfig + bus EventBus + itemType reflect.Type + storeSynHandler MessageHandler } type galacticStoreConfig struct { - syncChannelConfig *storeSyncChannelConfig + syncChannelConfig *storeSyncChannelConfig } func newBusStore(name string, bus EventBus, itemType reflect.Type, galacticConf *galacticStoreConfig) BusStore { - store := new(busStore) - store.name = name - store.bus = bus - store.itemType = itemType - store.galacticConf = galacticConf + store := new(busStore) + store.name = name + store.bus = bus + store.itemType = itemType + store.galacticConf = galacticConf - initStore(store) + initStore(store) - store.isGalactic = galacticConf != nil + store.isGalactic = galacticConf != nil - if store.isGalactic { - initGalacticStore(store) - } + if store.isGalactic { + initGalacticStore(store) + } - return store + return store } func initStore(store *busStore) { - store.readyC = make(chan struct{}) - store.storeStreams = []*storeStream {} - store.mutationStreams = []*mutationStoreStream {} - store.items = make(map[string]interface{}) - store.storeVersion = 1 - store.initializer = sync.Once{} + store.readyC = make(chan struct{}) + store.storeStreams = []*storeStream{} + store.mutationStreams = []*mutationStoreStream{} + store.items = make(map[string]interface{}) + store.storeVersion = 1 + store.initializer = sync.Once{} } func initGalacticStore(store *busStore) { - syncChannelConf := store.galacticConf.syncChannelConfig - - var err error - store.storeSynHandler, err = store.bus.ListenStream(syncChannelConf.syncChannelName) - if err != nil { - return - } - - store.storeSynHandler.Handle( - func(msg *model.Message) { - d := msg.Payload.([]byte) - var storeResponse map[string]interface{} - - err := json.Unmarshal(d, &storeResponse) - if err != nil { - log.Warn("failed to unmarshal storeResponse") - return - } - - if storeResponse["storeId"] != store.GetName() { - // the response is for another store - return - } - - responseType := storeResponse["responseType"].(string) - - switch responseType { - case "storeContentResponse": - - store.itemsLock.Lock() - defer store.itemsLock.Unlock() - - store.updateVersionFromResponse(storeResponse) - items := storeResponse["items"].(map[string]interface{}) - store.items = make(map[string]interface{}) - for key, val := range items { - deserializedValue, err := store.deserializeRawValue(val) - if err != nil { - log.Warn("failed to deserialize store item value %e", err) - continue - } else { - store.items[key] = deserializedValue - } - } - store.Initialize() - case "updateStoreResponse": - - store.itemsLock.Lock() - defer store.itemsLock.Unlock() - - store.updateVersionFromResponse(storeResponse) - newItemRaw, ok := storeResponse["newItemValue"] - itemId := storeResponse["itemId"].(string) - if !ok || newItemRaw == nil { - store.removeInternal(itemId, "galacticSyncRemove") - } else { - newItemValue, err := store.deserializeRawValue(newItemRaw) - if err != nil { - log.Warn("failed to deserialize store item value %e", err) - return - } - store.putInternal(itemId, newItemValue, "galacticSyncUpdate") - } - } - }, - func(e error) { - }) - - store.sendOpenStoreRequest() + syncChannelConf := store.galacticConf.syncChannelConfig + + var err error + store.storeSynHandler, err = store.bus.ListenStream(syncChannelConf.syncChannelName) + if err != nil { + return + } + + store.storeSynHandler.Handle( + func(msg *model.Message) { + d := msg.Payload.([]byte) + var storeResponse map[string]interface{} + + err := json.Unmarshal(d, &storeResponse) + if err != nil { + log.Warn("failed to unmarshal storeResponse") + return + } + + if storeResponse["storeId"] != store.GetName() { + // the response is for another store + return + } + + responseType := storeResponse["responseType"].(string) + + switch responseType { + case "storeContentResponse": + + store.itemsLock.Lock() + defer store.itemsLock.Unlock() + + store.updateVersionFromResponse(storeResponse) + items := storeResponse["items"].(map[string]interface{}) + store.items = make(map[string]interface{}) + for key, val := range items { + deserializedValue, err := store.deserializeRawValue(val) + if err != nil { + log.Warn("failed to deserialize store item value %e", err) + continue + } else { + store.items[key] = deserializedValue + } + } + store.Initialize() + case "updateStoreResponse": + + store.itemsLock.Lock() + defer store.itemsLock.Unlock() + + store.updateVersionFromResponse(storeResponse) + newItemRaw, ok := storeResponse["newItemValue"] + itemId := storeResponse["itemId"].(string) + if !ok || newItemRaw == nil { + store.removeInternal(itemId, "galacticSyncRemove") + } else { + newItemValue, err := store.deserializeRawValue(newItemRaw) + if err != nil { + log.Warn("failed to deserialize store item value %e", err) + return + } + store.putInternal(itemId, newItemValue, "galacticSyncUpdate") + } + } + }, + func(e error) { + }) + + store.sendOpenStoreRequest() } func (store *busStore) updateVersionFromResponse(storeResponse map[string]interface{}) { - version := storeResponse["storeVersion"] - switch version.(type) { - case float64: - store.storeVersion = int64(version.(float64)) - case int64: - store.storeVersion = version.(int64) - default: - log.Warn("failed to deserialize store version") - store.storeVersion = 1 - } + version := storeResponse["storeVersion"] + switch version.(type) { + case float64: + store.storeVersion = int64(version.(float64)) + case int64: + store.storeVersion = version.(int64) + default: + log.Warn("failed to deserialize store version") + store.storeVersion = 1 + } } func (store *busStore) deserializeRawValue(rawValue interface{}) (interface{}, error) { - return model.ConvertValueToType(rawValue, store.itemType) + return model.ConvertValueToType(rawValue, store.itemType) } func (store *busStore) sendOpenStoreRequest() { - openStoreReq := map[string]string { - "storeId": store.GetName(), - } - store.sendGalacticRequest("openStore", openStoreReq) + openStoreReq := map[string]string{ + "storeId": store.GetName(), + } + store.sendGalacticRequest("openStore", openStoreReq) } func (store *busStore) sendGalacticRequest(requestCmd string, requestPayload interface{}) { - // create request - id := uuid.New(); - r := &model.Request{} - r.Request = requestCmd - r.Payload = requestPayload - r.Id = &id - jsonReq, _ := json.Marshal(r) + // create request + id := uuid.New() + r := &model.Request{} + r.Request = requestCmd + r.Payload = requestPayload + r.Id = &id + jsonReq, _ := json.Marshal(r) - syncChannelConfig := store.galacticConf.syncChannelConfig + syncChannelConfig := store.galacticConf.syncChannelConfig - // send request. - syncChannelConfig.conn.SendJSONMessage( - syncChannelConfig.pubPrefix + syncChannelConfig.syncChannelName, - jsonReq) + // send request. + syncChannelConfig.conn.SendJSONMessage( + syncChannelConfig.pubPrefix+syncChannelConfig.syncChannelName, + jsonReq) } func (store *busStore) sendCloseStoreRequest() { - closeStoreReq := map[string]string { - "storeId": store.GetName(), - } - store.sendGalacticRequest("closeStore", closeStoreReq) + closeStoreReq := map[string]string{ + "storeId": store.GetName(), + } + store.sendGalacticRequest("closeStore", closeStoreReq) } func (store *busStore) OnDestroy() { - if store.IsGalactic() { - store.sendCloseStoreRequest() - if store.storeSynHandler != nil { - store.storeSynHandler.Close() - } - } + if store.IsGalactic() { + store.sendCloseStoreRequest() + if store.storeSynHandler != nil { + store.storeSynHandler.Close() + } + } } func (store *busStore) IsGalactic() bool { - return store.isGalactic + return store.isGalactic } func (store *busStore) GetItemType() reflect.Type { - return store.itemType + return store.itemType } func (store *busStore) GetName() string { - return store.name + return store.name } func (store *busStore) Populate(items map[string]interface{}) error { - if store.IsGalactic() { - return fmt.Errorf("populate() API is not supported for galactic stores") - } + if store.IsGalactic() { + return fmt.Errorf("populate() API is not supported for galactic stores") + } - store.itemsLock.Lock() - defer store.itemsLock.Unlock() + store.itemsLock.Lock() + defer store.itemsLock.Unlock() - if len(store.items) > 0 { - return fmt.Errorf("store items already initialized") - } + if len(store.items) > 0 { + return fmt.Errorf("store items already initialized") + } - for k,v := range items { - store.items[k] = v - } - store.Initialize() - return nil + for k, v := range items { + store.items[k] = v + } + store.Initialize() + return nil } func (store *busStore) Put(id string, value interface{}, state interface{}) { - if store.IsGalactic() { - store.putGalactic(id, value) - } else { - store.itemsLock.Lock() - defer store.itemsLock.Unlock() + if store.IsGalactic() { + store.putGalactic(id, value) + } else { + store.itemsLock.Lock() + defer store.itemsLock.Unlock() - store.putInternal(id, value, state) - } + store.putInternal(id, value, state) + } } func (store *busStore) putGalactic(id string, value interface{}) { - store.itemsLock.RLock() - clientStoreVersion := store.storeVersion - store.itemsLock.RUnlock() + store.itemsLock.RLock() + clientStoreVersion := store.storeVersion + store.itemsLock.RUnlock() - store.sendUpdateStoreRequest(id, value, clientStoreVersion) + store.sendUpdateStoreRequest(id, value, clientStoreVersion) } func (store *busStore) sendUpdateStoreRequest(id string, value interface{}, storeVersion int64) { - updateReq := map[string]interface{} { - "storeId": store.GetName(), - "clientStoreVersion": storeVersion, - "itemId": id, - "newItemValue": value, - } + updateReq := map[string]interface{}{ + "storeId": store.GetName(), + "clientStoreVersion": storeVersion, + "itemId": id, + "newItemValue": value, + } - store.sendGalacticRequest("updateStore", updateReq) + store.sendGalacticRequest("updateStore", updateReq) } func (store *busStore) putInternal(id string, value interface{}, state interface{}) { - if !store.IsGalactic() { - store.storeVersion++ - } - store.items[id] = value + if !store.IsGalactic() { + store.storeVersion++ + } + store.items[id] = value - change := &StoreChange{ - Id: id, - State: state, - Value: value, - StoreVersion: store.storeVersion, - } + change := &StoreChange{ + Id: id, + State: state, + Value: value, + StoreVersion: store.storeVersion, + } - go store.onStoreChange(change) + go store.onStoreChange(change) } func (store *busStore) Get(id string) (interface{}, bool) { - store.itemsLock.RLock() - defer store.itemsLock.RUnlock() + store.itemsLock.RLock() + defer store.itemsLock.RUnlock() - val, ok := store.items[id] + val, ok := store.items[id] - return val, ok + return val, ok } func (store *busStore) GetValue(id string) interface{} { - val, _ := store.Get(id) - return val + val, _ := store.Get(id) + return val } func (store *busStore) Remove(id string, state interface{}) bool { - if store.IsGalactic() { - return store.removeGalactic(id) - } else { - store.itemsLock.Lock() - defer store.itemsLock.Unlock() + if store.IsGalactic() { + return store.removeGalactic(id) + } else { + store.itemsLock.Lock() + defer store.itemsLock.Unlock() - return store.removeInternal(id, state) - } + return store.removeInternal(id, state) + } } func (store *busStore) removeGalactic(id string) bool { - store.itemsLock.RLock() - _, ok := store.items[id] - storeVersion := store.storeVersion - store.itemsLock.RUnlock() + store.itemsLock.RLock() + _, ok := store.items[id] + storeVersion := store.storeVersion + store.itemsLock.RUnlock() - if ok { - store.sendUpdateStoreRequest(id, nil, storeVersion) - return true - } - return false + if ok { + store.sendUpdateStoreRequest(id, nil, storeVersion) + return true + } + return false } func (store *busStore) removeInternal(id string, state interface{}) bool { - value, ok := store.items[id] - if !ok { - return false - } + value, ok := store.items[id] + if !ok { + return false + } - if !store.IsGalactic() { - store.storeVersion++ - } - delete(store.items, id) + if !store.IsGalactic() { + store.storeVersion++ + } + delete(store.items, id) - change := &StoreChange{ - Id: id, - State: state, - Value: value, - StoreVersion: store.storeVersion, - IsDeleteChange: true, - } + change := &StoreChange{ + Id: id, + State: state, + Value: value, + StoreVersion: store.storeVersion, + IsDeleteChange: true, + } - go store.onStoreChange(change) - return true + go store.onStoreChange(change) + return true } func (store *busStore) AllValues() []interface{} { - store.itemsLock.RLock() - defer store.itemsLock.RUnlock() + store.itemsLock.RLock() + defer store.itemsLock.RUnlock() - values := make([] interface{}, 0, len(store.items)) - for _, value := range store.items { - values = append(values, value) - } + values := make([]interface{}, 0, len(store.items)) + for _, value := range store.items { + values = append(values, value) + } - return values + return values } func (store *busStore) AllValuesAsMap() map[string]interface{} { - store.itemsLock.RLock() - defer store.itemsLock.RUnlock() + store.itemsLock.RLock() + defer store.itemsLock.RUnlock() - values := make(map[string] interface{}) + values := make(map[string]interface{}) - for key, value := range store.items { - values[key] = value - } + for key, value := range store.items { + values[key] = value + } - return values + return values } func (store *busStore) AllValuesAndVersion() (map[string]interface{}, int64) { - store.itemsLock.RLock() - defer store.itemsLock.RUnlock() + store.itemsLock.RLock() + defer store.itemsLock.RUnlock() - values := make(map[string] interface{}) + values := make(map[string]interface{}) - for key, value := range store.items { - values[key] = value - } + for key, value := range store.items { + values[key] = value + } - return values, store.storeVersion + return values, store.storeVersion } func (store *busStore) OnMutationRequest(requestType ...interface{}) MutationStoreStream { - return newMutationStoreStream(store, &mutationStreamFilter{ - requestTypes: requestType, - }) + return newMutationStoreStream(store, &mutationStreamFilter{ + requestTypes: requestType, + }) } func (store *busStore) Mutate(request interface{}, requestType interface{}, - successHandler func(interface{}), errorHandler func(interface{})) { + successHandler func(interface{}), errorHandler func(interface{})) { - store.mutationStreamsLock.RLock() - defer store.mutationStreamsLock.RUnlock() + store.mutationStreamsLock.RLock() + defer store.mutationStreamsLock.RUnlock() - for _, ms := range store.mutationStreams { - ms.onMutationRequest(&MutationRequest{ - Request: request, - RequestType: requestType, - SuccessHandler: successHandler, - ErrorHandler: errorHandler, - }) - } + for _, ms := range store.mutationStreams { + ms.onMutationRequest(&MutationRequest{ + Request: request, + RequestType: requestType, + SuccessHandler: successHandler, + ErrorHandler: errorHandler, + }) + } } -func(store *busStore) onStoreChange(change *StoreChange) { - store.storeStreamsLock.RLock() - defer store.storeStreamsLock.RUnlock() +func (store *busStore) onStoreChange(change *StoreChange) { + store.storeStreamsLock.RLock() + defer store.storeStreamsLock.RUnlock() - for _, storeStream := range store.storeStreams { - storeStream.onStoreChange(change) - } + for _, storeStream := range store.storeStreams { + storeStream.onStoreChange(change) + } } func (store *busStore) Initialize() { - store.initializer.Do(func() { - close(store.readyC) - store.bus.SendMonitorEvent(StoreInitializedEvt, store.name, nil) - }) + store.initializer.Do(func() { + close(store.readyC) + store.bus.SendMonitorEvent(StoreInitializedEvt, store.name, nil) + }) } func (store *busStore) Reset() { - store.itemsLock.Lock() - defer store.itemsLock.Unlock() + store.itemsLock.Lock() + defer store.itemsLock.Unlock() - store.mutationStreamsLock.Lock() - defer store.mutationStreamsLock.Unlock() + store.mutationStreamsLock.Lock() + defer store.mutationStreamsLock.Unlock() - store.storeStreamsLock.Lock() - defer store.storeStreamsLock.Unlock() + store.storeStreamsLock.Lock() + defer store.storeStreamsLock.Unlock() - initStore(store) + initStore(store) - if (store.IsGalactic()) { - store.sendOpenStoreRequest() - } + if store.IsGalactic() { + store.sendOpenStoreRequest() + } } func (store *busStore) WhenReady(readyFunc func()) { - go func() { - <- store.readyC - readyFunc() - }() + go func() { + <-store.readyC + readyFunc() + }() } func (store *busStore) OnChange(id string, state ...interface{}) StoreStream { - return newStoreStream(store, &streamFilter{ - itemId: id, - states: state, - }) + return newStoreStream(store, &streamFilter{ + itemId: id, + states: state, + }) } func (store *busStore) OnAllChanges(state ...interface{}) StoreStream { - return newStoreStream(store, &streamFilter{ - states: state, - matchAllItems: true, - }) + return newStoreStream(store, &streamFilter{ + states: state, + matchAllItems: true, + }) } func (store *busStore) onStreamSubscribe(stream *storeStream) { - store.storeStreamsLock.Lock() - defer store.storeStreamsLock.Unlock() + store.storeStreamsLock.Lock() + defer store.storeStreamsLock.Unlock() - store.storeStreams = append(store.storeStreams, stream) + store.storeStreams = append(store.storeStreams, stream) } func (store *busStore) onMutationStreamSubscribe(stream *mutationStoreStream) { - store.mutationStreamsLock.Lock() - defer store.mutationStreamsLock.Unlock() + store.mutationStreamsLock.Lock() + defer store.mutationStreamsLock.Unlock() - store.mutationStreams = append(store.mutationStreams, stream) + store.mutationStreams = append(store.mutationStreams, stream) } func (store *busStore) onStreamUnsubscribe(stream *storeStream) { - store.storeStreamsLock.Lock() - defer store.storeStreamsLock.Unlock() + store.storeStreamsLock.Lock() + defer store.storeStreamsLock.Unlock() - var i int - var s *storeStream - for i, s = range store.storeStreams { - if s == stream { - break - } - } + var i int + var s *storeStream + for i, s = range store.storeStreams { + if s == stream { + break + } + } - if s == stream { - n := len(store.storeStreams) - store.storeStreams[i] = store.storeStreams[n-1] - store.storeStreams = store.storeStreams[:n-1] - } + if s == stream { + n := len(store.storeStreams) + store.storeStreams[i] = store.storeStreams[n-1] + store.storeStreams = store.storeStreams[:n-1] + } } func (store *busStore) onMutationStreamUnsubscribe(stream *mutationStoreStream) { - store.mutationStreamsLock.Lock() - defer store.mutationStreamsLock.Unlock() - - var i int - var s *mutationStoreStream - for i, s = range store.mutationStreams { - if s == stream { - break - } - } - - if s == stream { - n := len(store.mutationStreams) - store.mutationStreams[i] = store.mutationStreams[n-1] - store.mutationStreams = store.mutationStreams[:n-1] - } + store.mutationStreamsLock.Lock() + defer store.mutationStreamsLock.Unlock() + + var i int + var s *mutationStoreStream + for i, s = range store.mutationStreams { + if s == stream { + break + } + } + + if s == stream { + n := len(store.mutationStreams) + store.mutationStreams[i] = store.mutationStreams[n-1] + store.mutationStreams = store.mutationStreams[:n-1] + } } diff --git a/bus/store_manager_test.go b/bus/store_manager_test.go index 323b823..5da4e45 100644 --- a/bus/store_manager_test.go +++ b/bus/store_manager_test.go @@ -4,150 +4,150 @@ package bus import ( - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "reflect" - "sync" - "sync/atomic" - "testing" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "reflect" + "sync" + "sync/atomic" + "testing" ) func createTestStoreManager() StoreManager { - return newStoreManager(GetBus()) + return newStoreManager(GetBus()) } func TestStoreManager_CreateStore(t *testing.T) { - storeManager := createTestStoreManager() - assert.NotNil(t, storeManager) - - wg := sync.WaitGroup{} - - for i := 0; i < 50; i++ { - wg.Add(1) - go func() { - store := storeManager.CreateStore("testStore") - assert.NotNil(t, store) - assert.Equal(t, store.GetName(), "testStore") - assert.Equal(t, store, storeManager.GetStore("testStore")) - wg.Done() - }() - } - - wg.Wait() - - store2 := storeManager.CreateStore("testStore2") - assert.NotEqual(t, store2, storeManager.GetStore("testStore")) + storeManager := createTestStoreManager() + assert.NotNil(t, storeManager) + + wg := sync.WaitGroup{} + + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + store := storeManager.CreateStore("testStore") + assert.NotNil(t, store) + assert.Equal(t, store.GetName(), "testStore") + assert.Equal(t, store, storeManager.GetStore("testStore")) + wg.Done() + }() + } + + wg.Wait() + + store2 := storeManager.CreateStore("testStore2") + assert.NotEqual(t, store2, storeManager.GetStore("testStore")) } func TestStoreManager_GetStore(t *testing.T) { - storeManager := createTestStoreManager() - storeManager.CreateStore("testStore") - store := storeManager.GetStore("testStore") - assert.Equal(t, store.GetName(), "testStore") + storeManager := createTestStoreManager() + storeManager.CreateStore("testStore") + store := storeManager.GetStore("testStore") + assert.Equal(t, store.GetName(), "testStore") - assert.Nil(t, storeManager.GetStore("invalid-store")) + assert.Nil(t, storeManager.GetStore("invalid-store")) } func TestStoreManager_DestroyStore(t *testing.T) { - storeManager := createTestStoreManager() - storeManager.CreateStore("testStore") + storeManager := createTestStoreManager() + storeManager.CreateStore("testStore") - var counter int32 = 0 + var counter int32 = 0 - wg := sync.WaitGroup{} + wg := sync.WaitGroup{} - for i := 0; i < 50; i++ { - wg.Add(1) - go func() { - if storeManager.DestroyStore("testStore") { - atomic.AddInt32(&counter, 1) - } - assert.Nil(t, storeManager.GetStore("testStore")) - wg.Done() - }() - } + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + if storeManager.DestroyStore("testStore") { + atomic.AddInt32(&counter, 1) + } + assert.Nil(t, storeManager.GetStore("testStore")) + wg.Done() + }() + } - wg.Wait() + wg.Wait() - // Verify that only one of the DestroyStore calls was successful (has returned true) - assert.Equal(t, counter, int32(1)) + // Verify that only one of the DestroyStore calls was successful (has returned true) + assert.Equal(t, counter, int32(1)) } func TestStoreManager_ConfigureStoreSyncChannel(t *testing.T) { - m := createTestStoreManager() - id := uuid.New() - con := &MockBridgeConnection{Id: &id} - - subId := uuid.New() - s := &MockBridgeSubscription{ - Id: &subId, - } - syncChannelDst := "/topic-prefix/transport-store-sync." + id.String() - con.On("Subscribe", syncChannelDst).Return(s, nil) - con.On("SendMessage", syncChannelDst, mock.Anything).Return(nil) - m.ConfigureStoreSyncChannel(con, "/topic-prefix", "/pub-prefix") - - storeManagerImpl := m.(*storeManager) - - conf, ok := storeManagerImpl.syncChannels[id] - assert.True(t, ok) - assert.NotNil(t, conf.syncChannelName) - assert.Equal(t, conf.pubPrefix, "/pub-prefix/") - assert.Equal(t, conf.topicPrefix, "/topic-prefix/") - assert.Equal(t, conf.conn, con) - - syncCh, _ := storeManagerImpl.eventBus.GetChannelManager().GetChannel(conf.syncChannelName) - assert.True(t, syncCh.galactic) - - err := m.ConfigureStoreSyncChannel(con, "/topic-prefix", "/pub-prefix") - assert.EqualError(t, err, "store sync channel already configured for this connection") + m := createTestStoreManager() + id := uuid.New() + con := &MockBridgeConnection{Id: &id} + + subId := uuid.New() + s := &MockBridgeSubscription{ + Id: &subId, + } + syncChannelDst := "/topic-prefix/transport-store-sync." + id.String() + con.On("Subscribe", syncChannelDst).Return(s, nil) + con.On("SendMessage", syncChannelDst, mock.Anything).Return(nil) + m.ConfigureStoreSyncChannel(con, "/topic-prefix", "/pub-prefix") + + storeManagerImpl := m.(*storeManager) + + conf, ok := storeManagerImpl.syncChannels[id] + assert.True(t, ok) + assert.NotNil(t, conf.syncChannelName) + assert.Equal(t, conf.pubPrefix, "/pub-prefix/") + assert.Equal(t, conf.topicPrefix, "/topic-prefix/") + assert.Equal(t, conf.conn, con) + + syncCh, _ := storeManagerImpl.eventBus.GetChannelManager().GetChannel(conf.syncChannelName) + assert.True(t, syncCh.galactic) + + err := m.ConfigureStoreSyncChannel(con, "/topic-prefix", "/pub-prefix") + assert.EqualError(t, err, "store sync channel already configured for this connection") } func TestStoreManager_OpenGalacticStore(t *testing.T) { - m := createTestStoreManager() - id := uuid.New() - con := &MockBridgeConnection{Id: &id} + m := createTestStoreManager() + id := uuid.New() + con := &MockBridgeConnection{Id: &id} - var s BusStore - var err error + var s BusStore + var err error - s, err = m.OpenGalacticStore("galacticStore", con) + s, err = m.OpenGalacticStore("galacticStore", con) - assert.Nil(t, s) - assert.EqualError(t, err, "sync channel is not configured for this connection") + assert.Nil(t, s) + assert.EqualError(t, err, "sync channel is not configured for this connection") - subId := uuid.New() - sub := &MockBridgeSubscription{ - Id: &subId, - } - con.On("Subscribe", mock.Anything).Return(sub, nil) - con.On("SendJSONMessage", mock.Anything, mock.Anything).Return(nil) - con.On("SendMessage", mock.Anything, mock.Anything, mock.Anything).Return(nil) - m.ConfigureStoreSyncChannel(con, "/topic-prefix", "/pub-prefix") + subId := uuid.New() + sub := &MockBridgeSubscription{ + Id: &subId, + } + con.On("Subscribe", mock.Anything).Return(sub, nil) + con.On("SendJSONMessage", mock.Anything, mock.Anything).Return(nil) + con.On("SendMessage", mock.Anything, mock.Anything, mock.Anything).Return(nil) + m.ConfigureStoreSyncChannel(con, "/topic-prefix", "/pub-prefix") - storeManagerImpl := m.(*storeManager) + storeManagerImpl := m.(*storeManager) - conf, _ := storeManagerImpl.syncChannels[id] - assert.NotNil(t, conf) + conf, _ := storeManagerImpl.syncChannels[id] + assert.NotNil(t, conf) - s, err = m.OpenGalacticStore("galacticStore", con) + s, err = m.OpenGalacticStore("galacticStore", con) - assert.Nil(t, err) - assert.True(t, s.IsGalactic()) + assert.Nil(t, err) + assert.True(t, s.IsGalactic()) - wg := sync.WaitGroup{} - wg.Add(1) + wg := sync.WaitGroup{} + wg.Add(1) - s.WhenReady(func() { - assert.Equal(t, len(s.AllValues()), 2) - assert.Equal(t, s.GetValue("id1"), "value1") - assert.Equal(t, s.GetValue("id2"), "value2") + s.WhenReady(func() { + assert.Equal(t, len(s.AllValues()), 2) + assert.Equal(t, s.GetValue("id1"), "value1") + assert.Equal(t, s.GetValue("id2"), "value2") - wg.Done() - }) + wg.Done() + }) - var jsonBlob = []byte(`{ + var jsonBlob = []byte(`{ "storeId": "galacticStore", "responseType": "storeContentResponse", "items": { @@ -156,58 +156,58 @@ func TestStoreManager_OpenGalacticStore(t *testing.T) { } }`) - storeManagerImpl.eventBus.SendResponseMessage(conf.syncChannelName, jsonBlob, nil) - wg.Wait() + storeManagerImpl.eventBus.SendResponseMessage(conf.syncChannelName, jsonBlob, nil) + wg.Wait() } type MockStoreItem struct { - From string `json:"from"` - Message string `json:"message"` + From string `json:"from"` + Message string `json:"message"` } func TestStoreManager_OpenGalacticStoreWithType(t *testing.T) { - m := createTestStoreManager() - id := uuid.New() - con := &MockBridgeConnection{Id: &id} + m := createTestStoreManager() + id := uuid.New() + con := &MockBridgeConnection{Id: &id} - subId := uuid.New() - sub := &MockBridgeSubscription{ - Id: &subId, - } - con.On("Subscribe", mock.Anything).Return(sub, nil) - con.On("SendJSONMessage", mock.Anything, mock.Anything).Return(nil) - con.On("SendMessage", mock.Anything, mock.Anything).Return(nil) - m.ConfigureStoreSyncChannel(con, "/topic-prefix", "/pub-prefix") + subId := uuid.New() + sub := &MockBridgeSubscription{ + Id: &subId, + } + con.On("Subscribe", mock.Anything).Return(sub, nil) + con.On("SendJSONMessage", mock.Anything, mock.Anything).Return(nil) + con.On("SendMessage", mock.Anything, mock.Anything).Return(nil) + m.ConfigureStoreSyncChannel(con, "/topic-prefix", "/pub-prefix") - storeManagerImpl := m.(*storeManager) + storeManagerImpl := m.(*storeManager) - conf, _ := storeManagerImpl.syncChannels[id] - assert.NotNil(t, conf) + conf, _ := storeManagerImpl.syncChannels[id] + assert.NotNil(t, conf) - store,_ := m.OpenGalacticStoreWithItemType("galacticStore", con, reflect.TypeOf(MockStoreItem{})) + store, _ := m.OpenGalacticStoreWithItemType("galacticStore", con, reflect.TypeOf(MockStoreItem{})) - store2 ,err := m.OpenGalacticStoreWithItemType("galacticStore", con, reflect.TypeOf(MockStoreItem{})) + store2, err := m.OpenGalacticStoreWithItemType("galacticStore", con, reflect.TypeOf(MockStoreItem{})) - assert.Equal(t, store, store2) - assert.Nil(t, err) + assert.Equal(t, store, store2) + assert.Nil(t, err) - assert.True(t, store.IsGalactic()) + assert.True(t, store.IsGalactic()) - wg := sync.WaitGroup{} - wg.Add(1) + wg := sync.WaitGroup{} + wg.Add(1) - store.WhenReady(func() { + store.WhenReady(func() { - assert.Equal(t, len(store.AllValues()), 1) - assert.Equal(t, store.GetValue("id1"), MockStoreItem{ - From: "test-user", - Message: "test-message", - }) + assert.Equal(t, len(store.AllValues()), 1) + assert.Equal(t, store.GetValue("id1"), MockStoreItem{ + From: "test-user", + Message: "test-message", + }) - wg.Done() - }) + wg.Done() + }) - var jsonBlob = []byte(`{ + var jsonBlob = []byte(`{ "storeId": "galacticStore", "responseType": "storeContentResponse", "items": { @@ -218,28 +218,27 @@ func TestStoreManager_OpenGalacticStoreWithType(t *testing.T) { } }`) - storeManagerImpl.eventBus.SendResponseMessage(conf.syncChannelName, jsonBlob, nil) - wg.Wait() + storeManagerImpl.eventBus.SendResponseMessage(conf.syncChannelName, jsonBlob, nil) + wg.Wait() } func TestStoreManager_OpenGalacticStoreWithLocalStoreId(t *testing.T) { - m := createTestStoreManager() - id := uuid.New() - con := &MockBridgeConnection{Id: &id} - subId := uuid.New() - sub := &MockBridgeSubscription{ - Id: &subId, - } - con.On("Subscribe", mock.Anything).Return(sub, nil) - con.On("SendMessage", mock.Anything, mock.Anything).Return(nil) - m.ConfigureStoreSyncChannel(con, "/topic-prefix", "/pub-prefix") - + m := createTestStoreManager() + id := uuid.New() + con := &MockBridgeConnection{Id: &id} + subId := uuid.New() + sub := &MockBridgeSubscription{ + Id: &subId, + } + con.On("Subscribe", mock.Anything).Return(sub, nil) + con.On("SendMessage", mock.Anything, mock.Anything).Return(nil) + m.ConfigureStoreSyncChannel(con, "/topic-prefix", "/pub-prefix") - localStore := m.CreateStore("localStore") + localStore := m.CreateStore("localStore") - store, err := m.OpenGalacticStoreWithItemType("localStore", con, reflect.TypeOf(MockStoreItem{})) + store, err := m.OpenGalacticStoreWithItemType("localStore", con, reflect.TypeOf(MockStoreItem{})) - assert.EqualError(t, err, "cannot open galactic store: there is a local store with the same name") - assert.Equal(t, store, localStore) + assert.EqualError(t, err, "cannot open galactic store: there is a local store with the same name") + assert.Equal(t, store, localStore) } diff --git a/bus/store_stream.go b/bus/store_stream.go index a48e089..a10e569 100644 --- a/bus/store_stream.go +++ b/bus/store_stream.go @@ -4,94 +4,94 @@ package bus import ( - "fmt" - "sync" + "fmt" + "sync" ) type StoreChangeHandlerFunction func(change *StoreChange) // Interface for subscribing for store changes type StoreStream interface { - // Subscribe to the store changes stream. - Subscribe(handler StoreChangeHandlerFunction) error - // Unsubscribe from the stream. - Unsubscribe() error + // Subscribe to the store changes stream. + Subscribe(handler StoreChangeHandlerFunction) error + // Unsubscribe from the stream. + Unsubscribe() error } type streamFilter struct { - states []interface{} - itemId string - matchAllItems bool + states []interface{} + itemId string + matchAllItems bool } func (f *streamFilter) match(change *StoreChange) bool { - if f.matchAllItems || f.itemId == change.Id { - if len(f.states) == 0 { - return true - } - - for _, s := range f.states { - if s == change.State { - return true - } - } - - } - return false + if f.matchAllItems || f.itemId == change.Id { + if len(f.states) == 0 { + return true + } + + for _, s := range f.states { + if s == change.State { + return true + } + } + + } + return false } type storeStream struct { - handler StoreChangeHandlerFunction - lock sync.RWMutex - store *busStore - filter *streamFilter + handler StoreChangeHandlerFunction + lock sync.RWMutex + store *busStore + filter *streamFilter } func newStoreStream(store *busStore, filter *streamFilter) *storeStream { - stream := new(storeStream) - stream.store = store - stream.filter = filter - return stream + stream := new(storeStream) + stream.store = store + stream.filter = filter + return stream } func (s *storeStream) Subscribe(handler StoreChangeHandlerFunction) error { - if handler == nil { - return fmt.Errorf("invalid StoreChangeHandlerFunction") - } - - s.lock.Lock() - if s.handler != nil { - s.lock.Unlock() - return fmt.Errorf("stream already subscribed") - } - s.handler = handler - s.lock.Unlock() - - s.store.onStreamSubscribe(s) - return nil + if handler == nil { + return fmt.Errorf("invalid StoreChangeHandlerFunction") + } + + s.lock.Lock() + if s.handler != nil { + s.lock.Unlock() + return fmt.Errorf("stream already subscribed") + } + s.handler = handler + s.lock.Unlock() + + s.store.onStreamSubscribe(s) + return nil } func (s *storeStream) Unsubscribe() error { - s.lock.Lock() - if s.handler == nil { - s.lock.Unlock() - return fmt.Errorf("stream not subscribed") - } - s.handler = nil - s.lock.Unlock() - - s.store.onStreamUnsubscribe(s) - return nil + s.lock.Lock() + if s.handler == nil { + s.lock.Unlock() + return fmt.Errorf("stream not subscribed") + } + s.handler = nil + s.lock.Unlock() + + s.store.onStreamUnsubscribe(s) + return nil } func (s *storeStream) onStoreChange(change *StoreChange) { - if !s.filter.match(change) { - return - } - - s.lock.RLock() - defer s.lock.RUnlock() - if s.handler != nil { - go s.handler(change) - } + if !s.filter.match(change) { + return + } + + s.lock.RLock() + defer s.lock.RUnlock() + if s.handler != nil { + go s.handler(change) + } } diff --git a/bus/store_sync_service.go b/bus/store_sync_service.go index fe6d2b5..5e10f88 100644 --- a/bus/store_sync_service.go +++ b/bus/store_sync_service.go @@ -4,296 +4,295 @@ package bus import ( - "github.com/google/uuid" - "github.com/vmware/transport-go/model" - "strings" - "sync" + "github.com/google/uuid" + "github.com/vmware/transport-go/model" + "strings" + "sync" ) const ( - openStoreRequest = "openStore" - updateStoreRequest = "updateStore" - closeStoreRequest = "closeStore" - galacticStoreSyncUpdate = "galacticStoreSyncUpdate" - galacticStoreSyncRemove = "galacticStoreSyncRemove" + openStoreRequest = "openStore" + updateStoreRequest = "updateStore" + closeStoreRequest = "closeStore" + galacticStoreSyncUpdate = "galacticStoreSyncUpdate" + galacticStoreSyncRemove = "galacticStoreSyncRemove" ) type storeSyncService struct { - bus EventBus - lock sync.Mutex - syncClients map[string]*syncClientChannel - syncStoreListeners map[string]*syncStoreListener + bus EventBus + lock sync.Mutex + syncClients map[string]*syncClientChannel + syncStoreListeners map[string]*syncStoreListener } type syncStoreListener struct { - storeStream StoreStream - clientSyncChannels map[string]bool - lock sync.RWMutex + storeStream StoreStream + clientSyncChannels map[string]bool + lock sync.RWMutex } - type syncClientChannel struct { - channelName string - clientRequestListener MessageHandler - openStores map[string]bool + channelName string + clientRequestListener MessageHandler + openStores map[string]bool } func newStoreSyncService(bus EventBus) *storeSyncService { - syncService := &storeSyncService{ - bus: bus, - syncClients: make(map[string]*syncClientChannel), - syncStoreListeners: make(map[string]*syncStoreListener), - } - syncService.init() - return syncService + syncService := &storeSyncService{ + bus: bus, + syncClients: make(map[string]*syncClientChannel), + syncStoreListeners: make(map[string]*syncStoreListener), + } + syncService.init() + return syncService } func (syncService *storeSyncService) init() { - syncService.bus.AddMonitorEventListener( - func(monitorEvt *MonitorEvent) { - if !strings.HasPrefix(monitorEvt.EntityName, "transport-store-sync.") { - // not a store sync channel, ignore the message - return - } - - switch monitorEvt.EventType { - case FabricEndpointSubscribeEvt: - syncService.openNewClientSyncChannel(monitorEvt.EntityName) - case ChannelDestroyedEvt: - syncService.closeClientSyncChannel(monitorEvt.EntityName) - } - }, - FabricEndpointSubscribeEvt, ChannelDestroyedEvt) + syncService.bus.AddMonitorEventListener( + func(monitorEvt *MonitorEvent) { + if !strings.HasPrefix(monitorEvt.EntityName, "transport-store-sync.") { + // not a store sync channel, ignore the message + return + } + + switch monitorEvt.EventType { + case FabricEndpointSubscribeEvt: + syncService.openNewClientSyncChannel(monitorEvt.EntityName) + case ChannelDestroyedEvt: + syncService.closeClientSyncChannel(monitorEvt.EntityName) + } + }, + FabricEndpointSubscribeEvt, ChannelDestroyedEvt) } func (syncService *storeSyncService) openNewClientSyncChannel(channelName string) { - syncService.lock.Lock() - defer syncService.lock.Unlock() - - if _, ok := syncService.syncClients[channelName]; ok { - // channel already opened. - return - } - - syncClient := &syncClientChannel{ - channelName: channelName, - openStores: make(map[string]bool), - } - syncClient.clientRequestListener, _ = syncService.bus.ListenRequestStream(channelName) - if syncClient.clientRequestListener != nil { - syncClient.clientRequestListener.Handle( - func(message *model.Message) { - request, reqOk := message.Payload.(*model.Request) - if !reqOk || request.Payload == nil { - return - } - var storeRequest map[string]interface{} - storeRequest, ok := request.Payload.(map[string]interface{}) - if !ok { - return - } - - switch request.Request { - case openStoreRequest: - syncService.openStore(syncClient, storeRequest, request.Id) - case closeStoreRequest: - syncService.closeStore(syncClient, storeRequest, request.Id) - case updateStoreRequest: - syncService.updateStore(syncClient, storeRequest, request.Id) - } - }, func(e error) {}) - } - syncService.syncClients[channelName] = syncClient + syncService.lock.Lock() + defer syncService.lock.Unlock() + + if _, ok := syncService.syncClients[channelName]; ok { + // channel already opened. + return + } + + syncClient := &syncClientChannel{ + channelName: channelName, + openStores: make(map[string]bool), + } + syncClient.clientRequestListener, _ = syncService.bus.ListenRequestStream(channelName) + if syncClient.clientRequestListener != nil { + syncClient.clientRequestListener.Handle( + func(message *model.Message) { + request, reqOk := message.Payload.(*model.Request) + if !reqOk || request.Payload == nil { + return + } + var storeRequest map[string]interface{} + storeRequest, ok := request.Payload.(map[string]interface{}) + if !ok { + return + } + + switch request.Request { + case openStoreRequest: + syncService.openStore(syncClient, storeRequest, request.Id) + case closeStoreRequest: + syncService.closeStore(syncClient, storeRequest, request.Id) + case updateStoreRequest: + syncService.updateStore(syncClient, storeRequest, request.Id) + } + }, func(e error) {}) + } + syncService.syncClients[channelName] = syncClient } func (syncService *storeSyncService) closeClientSyncChannel(channelName string) { - syncService.lock.Lock() - defer syncService.lock.Unlock() - - syncClient, ok := syncService.syncClients[channelName] - if !ok || syncClient == nil { - // client is already closed - return - } - - for storeId := range syncClient.openStores { - listener := syncService.syncStoreListeners[storeId] - if listener != nil { - listener.removeChannel(channelName) - if listener.isEmpty() { - listener.unsubscribe() - delete(syncService.syncStoreListeners, storeId) - } - } - } - - delete(syncService.syncClients, channelName) + syncService.lock.Lock() + defer syncService.lock.Unlock() + + syncClient, ok := syncService.syncClients[channelName] + if !ok || syncClient == nil { + // client is already closed + return + } + + for storeId := range syncClient.openStores { + listener := syncService.syncStoreListeners[storeId] + if listener != nil { + listener.removeChannel(channelName) + if listener.isEmpty() { + listener.unsubscribe() + delete(syncService.syncStoreListeners, storeId) + } + } + } + + delete(syncService.syncClients, channelName) } func (syncService *storeSyncService) openStore( - syncClient *syncClientChannel, request map[string]interface{}, reqId *uuid.UUID) { - - storeId, ok := getStingProperty("storeId", request) - if !ok || storeId == "" { - syncService.sendErrorResponse(syncClient.channelName, "Invalid OpenStoreRequest", reqId) - return - } - - store := syncService.bus.GetStoreManager().GetStore(storeId) - if store == nil { - syncService.sendErrorResponse( - syncClient.channelName, "Cannot open non-existing store: " + storeId, reqId) - return - } - - syncService.lock.Lock() - defer syncService.lock.Unlock() - - syncClient.openStores[storeId] = true - - storeListener, ok := syncService.syncStoreListeners[storeId] - if !ok { - storeListener = newSyncStoreListener(syncService.bus, store) - syncService.syncStoreListeners[storeId] = storeListener - } - storeListener.addChannel(syncClient.channelName) - - store.WhenReady(func() { - items, version := store.AllValuesAndVersion() - - syncService.bus.SendResponseMessage(syncClient.channelName, - model.NewStoreContentResponse(storeId, items, version), nil) - }) + syncClient *syncClientChannel, request map[string]interface{}, reqId *uuid.UUID) { + + storeId, ok := getStingProperty("storeId", request) + if !ok || storeId == "" { + syncService.sendErrorResponse(syncClient.channelName, "Invalid OpenStoreRequest", reqId) + return + } + + store := syncService.bus.GetStoreManager().GetStore(storeId) + if store == nil { + syncService.sendErrorResponse( + syncClient.channelName, "Cannot open non-existing store: "+storeId, reqId) + return + } + + syncService.lock.Lock() + defer syncService.lock.Unlock() + + syncClient.openStores[storeId] = true + + storeListener, ok := syncService.syncStoreListeners[storeId] + if !ok { + storeListener = newSyncStoreListener(syncService.bus, store) + syncService.syncStoreListeners[storeId] = storeListener + } + storeListener.addChannel(syncClient.channelName) + + store.WhenReady(func() { + items, version := store.AllValuesAndVersion() + + syncService.bus.SendResponseMessage(syncClient.channelName, + model.NewStoreContentResponse(storeId, items, version), nil) + }) } func (syncService *storeSyncService) closeStore( - syncClient *syncClientChannel, request map[string]interface{}, reqId *uuid.UUID) { - - storeId, ok := getStingProperty("storeId", request) - if !ok || storeId == "" { - syncService.sendErrorResponse(syncClient.channelName, "Invalid CloseStoreRequest", reqId) - return - } - - syncService.lock.Lock() - defer syncService.lock.Unlock() - - delete(syncClient.openStores, storeId) - - storeListener, ok := syncService.syncStoreListeners[storeId] - if ok && storeListener != nil { - storeListener.removeChannel(syncClient.channelName) - if storeListener.isEmpty() { - storeListener.unsubscribe() - delete(syncService.syncStoreListeners, storeId) - } - } + syncClient *syncClientChannel, request map[string]interface{}, reqId *uuid.UUID) { + + storeId, ok := getStingProperty("storeId", request) + if !ok || storeId == "" { + syncService.sendErrorResponse(syncClient.channelName, "Invalid CloseStoreRequest", reqId) + return + } + + syncService.lock.Lock() + defer syncService.lock.Unlock() + + delete(syncClient.openStores, storeId) + + storeListener, ok := syncService.syncStoreListeners[storeId] + if ok && storeListener != nil { + storeListener.removeChannel(syncClient.channelName) + if storeListener.isEmpty() { + storeListener.unsubscribe() + delete(syncService.syncStoreListeners, storeId) + } + } } func (syncService *storeSyncService) updateStore( - syncClient *syncClientChannel, request map[string]interface{}, reqId *uuid.UUID) { - - storeId, ok := getStingProperty("storeId", request) - if !ok || storeId == "" { - syncService.sendErrorResponse( - syncClient.channelName, "Invalid UpdateStoreRequest: missing storeId", reqId) - return - } - itemId, ok := getStingProperty("itemId", request) - if !ok || itemId == "" { - syncService.sendErrorResponse( - syncClient.channelName, "Invalid UpdateStoreRequest: missing itemId", reqId) - return - } - - store := syncService.bus.GetStoreManager().GetStore(storeId) - if store == nil { - syncService.sendErrorResponse( - syncClient.channelName, "Cannot update non-existing store: " + storeId, reqId) - return - } - - rawValue, ok := request["newItemValue"] - if rawValue == nil { - store.Remove(itemId, galacticStoreSyncRemove) - } else { - deserializedValue, err := model.ConvertValueToType(rawValue, store.GetItemType()) - if err != nil || deserializedValue == nil { - errMsg := "Cannot deserialize UpdateStoreRequest item value" - if err != nil { - errMsg = "Cannot deserialize UpdateStoreRequest item value: " + err.Error() - } - syncService.sendErrorResponse(syncClient.channelName, errMsg, reqId) - return - } - store.Put(itemId, deserializedValue, galacticStoreSyncUpdate) - } + syncClient *syncClientChannel, request map[string]interface{}, reqId *uuid.UUID) { + + storeId, ok := getStingProperty("storeId", request) + if !ok || storeId == "" { + syncService.sendErrorResponse( + syncClient.channelName, "Invalid UpdateStoreRequest: missing storeId", reqId) + return + } + itemId, ok := getStingProperty("itemId", request) + if !ok || itemId == "" { + syncService.sendErrorResponse( + syncClient.channelName, "Invalid UpdateStoreRequest: missing itemId", reqId) + return + } + + store := syncService.bus.GetStoreManager().GetStore(storeId) + if store == nil { + syncService.sendErrorResponse( + syncClient.channelName, "Cannot update non-existing store: "+storeId, reqId) + return + } + + rawValue, ok := request["newItemValue"] + if rawValue == nil { + store.Remove(itemId, galacticStoreSyncRemove) + } else { + deserializedValue, err := model.ConvertValueToType(rawValue, store.GetItemType()) + if err != nil || deserializedValue == nil { + errMsg := "Cannot deserialize UpdateStoreRequest item value" + if err != nil { + errMsg = "Cannot deserialize UpdateStoreRequest item value: " + err.Error() + } + syncService.sendErrorResponse(syncClient.channelName, errMsg, reqId) + return + } + store.Put(itemId, deserializedValue, galacticStoreSyncUpdate) + } } func getStingProperty(id string, request map[string]interface{}) (string, bool) { - propValue, ok := request[id] - if !ok || propValue == nil { - return "", false - } - stringValue, ok := propValue.(string) - return stringValue, ok + propValue, ok := request[id] + if !ok || propValue == nil { + return "", false + } + stringValue, ok := propValue.(string) + return stringValue, ok } func (syncService *storeSyncService) sendErrorResponse( - clientChannel string, errorMsg string, reqId *uuid.UUID) { - - syncService.bus.SendResponseMessage(clientChannel, &model.Response{ - Id: reqId, - Error: true, - ErrorCode: 1, - ErrorMessage: errorMsg, - }, nil) + clientChannel string, errorMsg string, reqId *uuid.UUID) { + + syncService.bus.SendResponseMessage(clientChannel, &model.Response{ + Id: reqId, + Error: true, + ErrorCode: 1, + ErrorMessage: errorMsg, + }, nil) } func newSyncStoreListener(bus EventBus, store BusStore) *syncStoreListener { - listener := &syncStoreListener{ - storeStream: store.OnAllChanges(), - clientSyncChannels: make(map[string]bool), - } + listener := &syncStoreListener{ + storeStream: store.OnAllChanges(), + clientSyncChannels: make(map[string]bool), + } - listener.storeStream.Subscribe(func(change *StoreChange) { - updateStoreResp := model.NewUpdateStoreResponse( - store.GetName(), change.Id, change.Value, change.StoreVersion) - if change.IsDeleteChange { - updateStoreResp.NewItemValue = nil - } + listener.storeStream.Subscribe(func(change *StoreChange) { + updateStoreResp := model.NewUpdateStoreResponse( + store.GetName(), change.Id, change.Value, change.StoreVersion) + if change.IsDeleteChange { + updateStoreResp.NewItemValue = nil + } - listener.lock.RLock() - defer listener.lock.RUnlock() + listener.lock.RLock() + defer listener.lock.RUnlock() - for chName := range listener.clientSyncChannels { - bus.SendResponseMessage(chName, updateStoreResp, nil) - } - }) + for chName := range listener.clientSyncChannels { + bus.SendResponseMessage(chName, updateStoreResp, nil) + } + }) - return listener + return listener } func (l *syncStoreListener) unsubscribe() { - l.storeStream.Unsubscribe() + l.storeStream.Unsubscribe() } func (l *syncStoreListener) addChannel(clientChannel string) { - l.lock.Lock() - defer l.lock.Unlock() - l.clientSyncChannels[clientChannel] = true + l.lock.Lock() + defer l.lock.Unlock() + l.clientSyncChannels[clientChannel] = true } func (l *syncStoreListener) removeChannel(clientChannel string) { - l.lock.Lock() - defer l.lock.Unlock() - delete(l.clientSyncChannels, clientChannel) + l.lock.Lock() + defer l.lock.Unlock() + delete(l.clientSyncChannels, clientChannel) } func (l *syncStoreListener) isEmpty() bool { - l.lock.Lock() - defer l.lock.Unlock() - return len(l.clientSyncChannels) == 0 + l.lock.Lock() + defer l.lock.Unlock() + return len(l.clientSyncChannels) == 0 } diff --git a/bus/store_sync_service_test.go b/bus/store_sync_service_test.go index ad3b9d3..40e5dda 100644 --- a/bus/store_sync_service_test.go +++ b/bus/store_sync_service_test.go @@ -4,511 +4,511 @@ package bus import ( - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/vmware/transport-go/model" - "reflect" - "strings" - "sync" - "testing" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/vmware/transport-go/model" + "reflect" + "strings" + "sync" + "testing" ) func testStoreSyncService() (*storeSyncService, EventBus) { - bus := newTestEventBus() - return newStoreSyncService(bus), bus + bus := newTestEventBus() + return newStoreSyncService(bus), bus } func TestStoreSyncService_NewConnection(t *testing.T) { - service, bus := testStoreSyncService() + service, bus := testStoreSyncService() - // verify that the service ignores non transport-store-sync events - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, "galactic-channel", nil) - assert.Equal(t, len(service.syncClients), 0) + // verify that the service ignores non transport-store-sync events + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, "galactic-channel", nil) + assert.Equal(t, len(service.syncClients), 0) - syncChan := "transport-store-sync.1" + syncChan := "transport-store-sync.1" - bus.GetChannelManager().CreateChannel(syncChan) + bus.GetChannelManager().CreateChannel(syncChan) - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) - assert.Equal(t, len(service.syncClients), 1) + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) + assert.Equal(t, len(service.syncClients), 1) } func TestStoreSyncService_OpenStoreErrors(t *testing.T) { - _, bus := testStoreSyncService() - - syncChan := "transport-store-sync.1" - bus.GetChannelManager().CreateChannel(syncChan) - - mh, _ := bus.ListenStream(syncChan) - wg := sync.WaitGroup{} - var errors []*model.Response - mh.Handle(func(message *model.Message) { - errors = append(errors, message.Payload.(*model.Response)) - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) - id := uuid.New() - bus.SendRequestMessage(syncChan, "invalid-request", nil) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: openStoreRequest, - Payload: "invalid-payload", - Id: &id, - }, nil) - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: openStoreRequest, - Payload: make(map[string]interface{}), - Id: &id, - }, nil) - wg.Wait() - - assert.Equal(t, errors[0].Id, &id) - assert.True(t, errors[0].Error) - assert.Equal(t, errors[0].ErrorMessage, "Invalid OpenStoreRequest") - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{} { "storeId": "non-existing-store" }, - Id: &id, - }, nil) - wg.Wait() - - assert.Equal(t, errors[1].Id, &id) - assert.True(t, errors[1].Error) - assert.Equal(t, errors[1].ErrorMessage, "Cannot open non-existing store: non-existing-store") + _, bus := testStoreSyncService() + + syncChan := "transport-store-sync.1" + bus.GetChannelManager().CreateChannel(syncChan) + + mh, _ := bus.ListenStream(syncChan) + wg := sync.WaitGroup{} + var errors []*model.Response + mh.Handle(func(message *model.Message) { + errors = append(errors, message.Payload.(*model.Response)) + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) + id := uuid.New() + bus.SendRequestMessage(syncChan, "invalid-request", nil) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: openStoreRequest, + Payload: "invalid-payload", + Id: &id, + }, nil) + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: openStoreRequest, + Payload: make(map[string]interface{}), + Id: &id, + }, nil) + wg.Wait() + + assert.Equal(t, errors[0].Id, &id) + assert.True(t, errors[0].Error) + assert.Equal(t, errors[0].ErrorMessage, "Invalid OpenStoreRequest") + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "non-existing-store"}, + Id: &id, + }, nil) + wg.Wait() + + assert.Equal(t, errors[1].Id, &id) + assert.True(t, errors[1].Error) + assert.Equal(t, errors[1].ErrorMessage, "Cannot open non-existing store: non-existing-store") } func TestStoreSyncService_OpenStore(t *testing.T) { - service, bus := testStoreSyncService() - - store := bus.GetStoreManager().CreateStoreWithType( - "test-store", reflect.TypeOf(&MockStoreItem{})) - store.Populate(map[string]interface{} { - "item1": &MockStoreItem{From:"test", Message:"test-message"}, - "item2": &MockStoreItem{From:"test2", Message: uuid.New().String()}, - }) - - syncChan := "transport-store-sync.1" - bus.GetChannelManager().CreateChannel(syncChan) - - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) - - wg := sync.WaitGroup{} - var syncResp [] interface{} - - mh, _ := bus.ListenStream(syncChan) - mh.Handle(func(message *model.Message) { - syncResp = append(syncResp, message.Payload) - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{} { "storeId": "test-store" }, - }, nil) - wg.Wait() - - assert.Equal(t, len(service.syncClients[syncChan].openStores), 1) - assert.Equal(t, len(service.syncStoreListeners), 1) - assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan], true) - - resp := syncResp[0].(*model.StoreContentResponse) - - assert.Equal(t, resp.StoreId, "test-store") - items, version := store.AllValuesAndVersion() - - assert.Equal(t, resp.StoreVersion, version) - assert.Equal(t, resp.Items, items) - assert.Equal(t, resp.ResponseType, "storeContentResponse") - - // try subscribing to the same sync channel again - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) - assert.Equal(t, len(service.syncClients[syncChan].openStores), 1) - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{} { "storeId": "test-store" }, - }, nil) - wg.Wait() - - assert.Equal(t, len(syncResp), 2) - assert.Equal(t, syncResp[1].(*model.StoreContentResponse).ResponseType, "storeContentResponse") - - syncChan2 := "transport-store-sync.2" - bus.GetChannelManager().CreateChannel(syncChan2) - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan2, nil) - - mh2, _ := bus.ListenStream(syncChan2) - mh2.Handle(func(message *model.Message) { - syncResp = append(syncResp, message.Payload) - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - wg.Add(1) - bus.SendRequestMessage(syncChan2, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{} { "storeId": "test-store" }, - }, nil) - wg.Wait() - - assert.Equal(t, len(syncResp), 3) - assert.Equal(t, syncResp[2].(*model.StoreContentResponse).ResponseType, "storeContentResponse") - - assert.Equal(t, len(service.syncClients), 2) - assert.Equal(t, len(service.syncClients[syncChan].openStores), 1) - assert.Equal(t, len(service.syncClients[syncChan2].openStores), 1) - assert.Equal(t, service.syncClients[syncChan2].openStores["test-store"], true) - - assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 2) - assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan2], true) - - bus.SendMonitorEvent(ChannelDestroyedEvt, syncChan, nil) - - assert.Equal(t, len(service.syncClients), 1) - assert.Equal(t, len(service.syncClients[syncChan2].openStores), 1) - assert.Equal(t, service.syncClients[syncChan2].openStores["test-store"], true) - assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 1) - assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan2], true) - - bus.SendMonitorEvent(ChannelDestroyedEvt, syncChan2, nil) - - assert.Equal(t, len(service.syncClients), 0) - assert.Equal(t, len(service.syncStoreListeners), 0) - - // try closing the syncChan2 again - bus.SendMonitorEvent(ChannelDestroyedEvt, syncChan2, nil) + service, bus := testStoreSyncService() + + store := bus.GetStoreManager().CreateStoreWithType( + "test-store", reflect.TypeOf(&MockStoreItem{})) + store.Populate(map[string]interface{}{ + "item1": &MockStoreItem{From: "test", Message: "test-message"}, + "item2": &MockStoreItem{From: "test2", Message: uuid.New().String()}, + }) + + syncChan := "transport-store-sync.1" + bus.GetChannelManager().CreateChannel(syncChan) + + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) + + wg := sync.WaitGroup{} + var syncResp []interface{} + + mh, _ := bus.ListenStream(syncChan) + mh.Handle(func(message *model.Message) { + syncResp = append(syncResp, message.Payload) + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + }, nil) + wg.Wait() + + assert.Equal(t, len(service.syncClients[syncChan].openStores), 1) + assert.Equal(t, len(service.syncStoreListeners), 1) + assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan], true) + + resp := syncResp[0].(*model.StoreContentResponse) + + assert.Equal(t, resp.StoreId, "test-store") + items, version := store.AllValuesAndVersion() + + assert.Equal(t, resp.StoreVersion, version) + assert.Equal(t, resp.Items, items) + assert.Equal(t, resp.ResponseType, "storeContentResponse") + + // try subscribing to the same sync channel again + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) + assert.Equal(t, len(service.syncClients[syncChan].openStores), 1) + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + }, nil) + wg.Wait() + + assert.Equal(t, len(syncResp), 2) + assert.Equal(t, syncResp[1].(*model.StoreContentResponse).ResponseType, "storeContentResponse") + + syncChan2 := "transport-store-sync.2" + bus.GetChannelManager().CreateChannel(syncChan2) + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan2, nil) + + mh2, _ := bus.ListenStream(syncChan2) + mh2.Handle(func(message *model.Message) { + syncResp = append(syncResp, message.Payload) + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + wg.Add(1) + bus.SendRequestMessage(syncChan2, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + }, nil) + wg.Wait() + + assert.Equal(t, len(syncResp), 3) + assert.Equal(t, syncResp[2].(*model.StoreContentResponse).ResponseType, "storeContentResponse") + + assert.Equal(t, len(service.syncClients), 2) + assert.Equal(t, len(service.syncClients[syncChan].openStores), 1) + assert.Equal(t, len(service.syncClients[syncChan2].openStores), 1) + assert.Equal(t, service.syncClients[syncChan2].openStores["test-store"], true) + + assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 2) + assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan2], true) + + bus.SendMonitorEvent(ChannelDestroyedEvt, syncChan, nil) + + assert.Equal(t, len(service.syncClients), 1) + assert.Equal(t, len(service.syncClients[syncChan2].openStores), 1) + assert.Equal(t, service.syncClients[syncChan2].openStores["test-store"], true) + assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 1) + assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan2], true) + + bus.SendMonitorEvent(ChannelDestroyedEvt, syncChan2, nil) + + assert.Equal(t, len(service.syncClients), 0) + assert.Equal(t, len(service.syncStoreListeners), 0) + + // try closing the syncChan2 again + bus.SendMonitorEvent(ChannelDestroyedEvt, syncChan2, nil) } func TestStoreSyncService_CloseStore(t *testing.T) { - service, bus := testStoreSyncService() - - store := bus.GetStoreManager().CreateStoreWithType( - "test-store", reflect.TypeOf(&MockStoreItem{})) - store.Populate(map[string]interface{}{ - "item1": &MockStoreItem{From: "test", Message: "test-message"}, - "item2": &MockStoreItem{From: "test2", Message: uuid.New().String()}, - }) - - syncChan := "transport-store-sync.1" - bus.GetChannelManager().CreateChannel(syncChan) - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) - - syncChan2 := "transport-store-sync.2" - bus.GetChannelManager().CreateChannel(syncChan2) - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan2, nil) - - wg := sync.WaitGroup{} - var syncResp1 [] interface{} - - mh, _ := bus.ListenStream(syncChan) - mh.Handle(func(message *model.Message) { - syncResp1 = append(syncResp1, message.Payload) - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - mh2, _ := bus.ListenStream(syncChan2) - mh2.Handle(func(message *model.Message) { - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - id := uuid.New() - - wg.Add(2) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{} { "storeId": "test-store" }, - }, nil) - bus.SendRequestMessage(syncChan2, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{} { "storeId": "test-store" }, - }, nil) - wg.Wait() - - assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 2) - - bus.SendRequestMessage(syncChan, &model.Request{ - Request: closeStoreRequest, - Payload: map[string]interface{} { "storeId": "test-store"}, - Id: &id, - }, nil) - - wg.Add(2) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: closeStoreRequest, - Payload: make(map[string]interface{}), - Id: &id, - }, nil) - bus.SendRequestMessage(syncChan2, &model.Request{ - Request: closeStoreRequest, - Payload: map[string]interface{} {"storeId": ""}, - Id: &id, - }, nil) - wg.Wait() - - assert.Equal(t, syncResp1[1].(*model.Response).ErrorMessage, "Invalid CloseStoreRequest") - assert.Equal(t, syncResp1[1].(*model.Response).Id, &id) - assert.Equal(t, syncResp1[1].(*model.Response).Error, true) - - service.lock.Lock() - assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 1) - assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan2], true) - assert.Equal(t, len(service.syncClients[syncChan].openStores), 0) - assert.Equal(t, len(service.syncClients[syncChan2].openStores), 1) - service.lock.Unlock() - - bus.SendRequestMessage(syncChan2, &model.Request{ - Request: closeStoreRequest, - Payload: map[string]interface{} { "storeId": "test-store"}, - Id: &id, - }, nil) - - wg.Add(2) - bus.SendRequestMessage(syncChan2, &model.Request{ - Request: closeStoreRequest, - Payload: make(map[string]interface{}), - Id: &id, - }, nil) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: closeStoreRequest, - Payload: map[string]interface{} {"storeId": ""}, - Id: &id, - }, nil) - wg.Wait() - - assert.Equal(t, syncResp1[2].(*model.Response).ErrorMessage, "Invalid CloseStoreRequest") - assert.Equal(t, syncResp1[2].(*model.Response).Id, &id) - assert.Equal(t, syncResp1[2].(*model.Response).Error, true) - - service.lock.Lock() - assert.Equal(t, len(service.syncStoreListeners), 0) - assert.Equal(t, len(service.syncClients[syncChan].openStores), 0) - assert.Equal(t, len(service.syncClients[syncChan2].openStores), 0) - service.lock.Unlock() + service, bus := testStoreSyncService() + + store := bus.GetStoreManager().CreateStoreWithType( + "test-store", reflect.TypeOf(&MockStoreItem{})) + store.Populate(map[string]interface{}{ + "item1": &MockStoreItem{From: "test", Message: "test-message"}, + "item2": &MockStoreItem{From: "test2", Message: uuid.New().String()}, + }) + + syncChan := "transport-store-sync.1" + bus.GetChannelManager().CreateChannel(syncChan) + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) + + syncChan2 := "transport-store-sync.2" + bus.GetChannelManager().CreateChannel(syncChan2) + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan2, nil) + + wg := sync.WaitGroup{} + var syncResp1 []interface{} + + mh, _ := bus.ListenStream(syncChan) + mh.Handle(func(message *model.Message) { + syncResp1 = append(syncResp1, message.Payload) + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + mh2, _ := bus.ListenStream(syncChan2) + mh2.Handle(func(message *model.Message) { + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + id := uuid.New() + + wg.Add(2) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + }, nil) + bus.SendRequestMessage(syncChan2, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + }, nil) + wg.Wait() + + assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 2) + + bus.SendRequestMessage(syncChan, &model.Request{ + Request: closeStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + Id: &id, + }, nil) + + wg.Add(2) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: closeStoreRequest, + Payload: make(map[string]interface{}), + Id: &id, + }, nil) + bus.SendRequestMessage(syncChan2, &model.Request{ + Request: closeStoreRequest, + Payload: map[string]interface{}{"storeId": ""}, + Id: &id, + }, nil) + wg.Wait() + + assert.Equal(t, syncResp1[1].(*model.Response).ErrorMessage, "Invalid CloseStoreRequest") + assert.Equal(t, syncResp1[1].(*model.Response).Id, &id) + assert.Equal(t, syncResp1[1].(*model.Response).Error, true) + + service.lock.Lock() + assert.Equal(t, len(service.syncStoreListeners["test-store"].clientSyncChannels), 1) + assert.Equal(t, service.syncStoreListeners["test-store"].clientSyncChannels[syncChan2], true) + assert.Equal(t, len(service.syncClients[syncChan].openStores), 0) + assert.Equal(t, len(service.syncClients[syncChan2].openStores), 1) + service.lock.Unlock() + + bus.SendRequestMessage(syncChan2, &model.Request{ + Request: closeStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + Id: &id, + }, nil) + + wg.Add(2) + bus.SendRequestMessage(syncChan2, &model.Request{ + Request: closeStoreRequest, + Payload: make(map[string]interface{}), + Id: &id, + }, nil) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: closeStoreRequest, + Payload: map[string]interface{}{"storeId": ""}, + Id: &id, + }, nil) + wg.Wait() + + assert.Equal(t, syncResp1[2].(*model.Response).ErrorMessage, "Invalid CloseStoreRequest") + assert.Equal(t, syncResp1[2].(*model.Response).Id, &id) + assert.Equal(t, syncResp1[2].(*model.Response).Error, true) + + service.lock.Lock() + assert.Equal(t, len(service.syncStoreListeners), 0) + assert.Equal(t, len(service.syncClients[syncChan].openStores), 0) + assert.Equal(t, len(service.syncClients[syncChan2].openStores), 0) + service.lock.Unlock() } func TestStoreSyncService_UpdateStoreErrors(t *testing.T) { - _, bus := testStoreSyncService() - - syncChan := "transport-store-sync.1" - bus.GetChannelManager().CreateChannel(syncChan) - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) - - wg := sync.WaitGroup{} - var syncResp [] interface{} - - mh, _ := bus.ListenStream(syncChan) - mh.Handle(func(message *model.Message) { - syncResp = append(syncResp, message.Payload) - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - id := uuid.New() - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: updateStoreRequest, - Payload: map[string]interface{} {}, - Id: &id, - }, nil) - wg.Wait() - - assert.Equal(t, syncResp[0].(*model.Response).ErrorMessage, "Invalid UpdateStoreRequest: missing storeId") - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: updateStoreRequest, - Payload: map[string]interface{} { "storeId": "test-store"}, - Id: &id, - }, nil) - wg.Wait() - - assert.Equal(t, syncResp[1].(*model.Response).ErrorMessage, "Invalid UpdateStoreRequest: missing itemId") - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: updateStoreRequest, - Payload: map[string]interface{} { "storeId": "test-store", "itemId": "item1"}, - Id: &id, - }, nil) - wg.Wait() - - assert.Equal(t, syncResp[2].(*model.Response).ErrorMessage, "Cannot update non-existing store: test-store") + _, bus := testStoreSyncService() + + syncChan := "transport-store-sync.1" + bus.GetChannelManager().CreateChannel(syncChan) + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) + + wg := sync.WaitGroup{} + var syncResp []interface{} + + mh, _ := bus.ListenStream(syncChan) + mh.Handle(func(message *model.Message) { + syncResp = append(syncResp, message.Payload) + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + id := uuid.New() + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: updateStoreRequest, + Payload: map[string]interface{}{}, + Id: &id, + }, nil) + wg.Wait() + + assert.Equal(t, syncResp[0].(*model.Response).ErrorMessage, "Invalid UpdateStoreRequest: missing storeId") + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: updateStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + Id: &id, + }, nil) + wg.Wait() + + assert.Equal(t, syncResp[1].(*model.Response).ErrorMessage, "Invalid UpdateStoreRequest: missing itemId") + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: updateStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store", "itemId": "item1"}, + Id: &id, + }, nil) + wg.Wait() + + assert.Equal(t, syncResp[2].(*model.Response).ErrorMessage, "Cannot update non-existing store: test-store") } func TestStoreSyncService_UpdateStore(t *testing.T) { - _, bus := testStoreSyncService() - - store := bus.GetStoreManager().CreateStoreWithType( - "test-store", reflect.TypeOf(&MockStoreItem{})) - store.Populate(map[string]interface{}{ - "item1": &MockStoreItem{From: "test", Message: "test-message"}, - "item2": &MockStoreItem{From: "test2", Message: uuid.New().String()}, - }) - - syncChan := "transport-store-sync.1" - bus.GetChannelManager().CreateChannel(syncChan) - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) - - syncChan2 := "transport-store-sync.2" - bus.GetChannelManager().CreateChannel(syncChan2) - bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan2, nil) - - wg := sync.WaitGroup{} - var syncResp1 [] interface{} - var syncResp2 [] interface{} - - mh, _ := bus.ListenStream(syncChan) - mh.Handle(func(message *model.Message) { - syncResp1 = append(syncResp1, message.Payload) - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - mh2, _ := bus.ListenStream(syncChan2) - mh2.Handle(func(message *model.Message) { - syncResp2 = append(syncResp2, message.Payload) - wg.Done() - }, func(e error) { - assert.Fail(t, "Unexpected error") - }) - - wg.Add(2) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{} { "storeId": "test-store" }, - }, nil) - bus.SendRequestMessage(syncChan2, &model.Request{ - Request: openStoreRequest, - Payload: map[string]interface{} { "storeId": "test-store" }, - }, nil) - wg.Wait() - - assert.Equal(t, len(syncResp1), 1) - assert.Equal(t, len(syncResp2), 1) - - wg.Add(2) - - bus.SendRequestMessage(syncChan, &model.Request{ - Request: updateStoreRequest, - Payload: map[string]interface{} { - "storeId": "test-store", - "itemId": "item3", - "newItemValue": map[string]interface{} { - "From": "test3", - "Message": "test-message3", - }}, - }, nil) - - wg.Wait() - - assert.Equal(t, len(syncResp1), 2) - assert.Equal(t, len(syncResp2), 2) - - assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") - assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).StoreId, "test-store") - assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).StoreVersion, int64(2)) - assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).NewItemValue, &MockStoreItem{ - From: "test3", - Message: "test-message3", - }) - - assert.Equal(t, syncResp1[1], syncResp2[1]) - - assert.Equal(t, store.GetValue("item3"), &MockStoreItem{ - From: "test3", - Message: "test-message3", - }) - - wg.Add(2) - store.Remove("item2", "test-remove") - wg.Wait() - - assert.Equal(t, len(syncResp1), 3) - assert.Equal(t, len(syncResp2), 3) - - assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") - assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).StoreId, "test-store") - assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).ItemId, "item2") - assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).StoreVersion, int64(3)) - assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).NewItemValue, nil) - - assert.Equal(t, syncResp1[2], syncResp2[2]) - - wg.Add(2) - store.Put("item1", &MockStoreItem{From: "u1", Message: "m1"}, nil) - wg.Wait() - - assert.Equal(t, len(syncResp1), 4) - assert.Equal(t, len(syncResp2), 4) - - assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") - assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).StoreId, "test-store") - assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).ItemId, "item1") - assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).StoreVersion, int64(4)) - assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).NewItemValue, - &MockStoreItem{From: "u1", Message: "m1"}) - - assert.Equal(t, syncResp1[3], syncResp2[3]) - - bus.SendRequestMessage(syncChan, &model.Request{ - Request: updateStoreRequest, - Payload: map[string]interface{} { - "storeId": "test-store", - "itemId": "item4", - "newItemValue": nil}, - }, nil) - - wg.Add(2) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: updateStoreRequest, - Payload: map[string]interface{} { - "storeId": "test-store", - "itemId": "item3", - "newItemValue": nil}, - }, nil) - wg.Wait() - - assert.Equal(t, len(syncResp1), 5) - assert.Equal(t, len(syncResp2), 5) - - assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") - assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).StoreId, "test-store") - assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).ItemId, "item3") - assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).StoreVersion, int64(5)) - assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).NewItemValue, nil) - - assert.Equal(t, syncResp1[4], syncResp2[4]) - - assert.Equal(t, store.GetValue("item3"), nil) - - wg.Add(1) - bus.SendRequestMessage(syncChan, &model.Request{ - Request: updateStoreRequest, - Payload: map[string]interface{} { - "storeId": "test-store", - "itemId": "item3", - "newItemValue": "test"}, - }, nil) - wg.Wait() - assert.Equal(t, len(syncResp1), 6) - assert.True(t, strings.HasPrefix(syncResp1[5].(*model.Response).ErrorMessage, - "Cannot deserialize UpdateStoreRequest item value:")) + _, bus := testStoreSyncService() + + store := bus.GetStoreManager().CreateStoreWithType( + "test-store", reflect.TypeOf(&MockStoreItem{})) + store.Populate(map[string]interface{}{ + "item1": &MockStoreItem{From: "test", Message: "test-message"}, + "item2": &MockStoreItem{From: "test2", Message: uuid.New().String()}, + }) + + syncChan := "transport-store-sync.1" + bus.GetChannelManager().CreateChannel(syncChan) + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan, nil) + + syncChan2 := "transport-store-sync.2" + bus.GetChannelManager().CreateChannel(syncChan2) + bus.SendMonitorEvent(FabricEndpointSubscribeEvt, syncChan2, nil) + + wg := sync.WaitGroup{} + var syncResp1 []interface{} + var syncResp2 []interface{} + + mh, _ := bus.ListenStream(syncChan) + mh.Handle(func(message *model.Message) { + syncResp1 = append(syncResp1, message.Payload) + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + mh2, _ := bus.ListenStream(syncChan2) + mh2.Handle(func(message *model.Message) { + syncResp2 = append(syncResp2, message.Payload) + wg.Done() + }, func(e error) { + assert.Fail(t, "Unexpected error") + }) + + wg.Add(2) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + }, nil) + bus.SendRequestMessage(syncChan2, &model.Request{ + Request: openStoreRequest, + Payload: map[string]interface{}{"storeId": "test-store"}, + }, nil) + wg.Wait() + + assert.Equal(t, len(syncResp1), 1) + assert.Equal(t, len(syncResp2), 1) + + wg.Add(2) + + bus.SendRequestMessage(syncChan, &model.Request{ + Request: updateStoreRequest, + Payload: map[string]interface{}{ + "storeId": "test-store", + "itemId": "item3", + "newItemValue": map[string]interface{}{ + "From": "test3", + "Message": "test-message3", + }}, + }, nil) + + wg.Wait() + + assert.Equal(t, len(syncResp1), 2) + assert.Equal(t, len(syncResp2), 2) + + assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") + assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).StoreId, "test-store") + assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).StoreVersion, int64(2)) + assert.Equal(t, syncResp1[1].(*model.UpdateStoreResponse).NewItemValue, &MockStoreItem{ + From: "test3", + Message: "test-message3", + }) + + assert.Equal(t, syncResp1[1], syncResp2[1]) + + assert.Equal(t, store.GetValue("item3"), &MockStoreItem{ + From: "test3", + Message: "test-message3", + }) + + wg.Add(2) + store.Remove("item2", "test-remove") + wg.Wait() + + assert.Equal(t, len(syncResp1), 3) + assert.Equal(t, len(syncResp2), 3) + + assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") + assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).StoreId, "test-store") + assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).ItemId, "item2") + assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).StoreVersion, int64(3)) + assert.Equal(t, syncResp1[2].(*model.UpdateStoreResponse).NewItemValue, nil) + + assert.Equal(t, syncResp1[2], syncResp2[2]) + + wg.Add(2) + store.Put("item1", &MockStoreItem{From: "u1", Message: "m1"}, nil) + wg.Wait() + + assert.Equal(t, len(syncResp1), 4) + assert.Equal(t, len(syncResp2), 4) + + assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") + assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).StoreId, "test-store") + assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).ItemId, "item1") + assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).StoreVersion, int64(4)) + assert.Equal(t, syncResp1[3].(*model.UpdateStoreResponse).NewItemValue, + &MockStoreItem{From: "u1", Message: "m1"}) + + assert.Equal(t, syncResp1[3], syncResp2[3]) + + bus.SendRequestMessage(syncChan, &model.Request{ + Request: updateStoreRequest, + Payload: map[string]interface{}{ + "storeId": "test-store", + "itemId": "item4", + "newItemValue": nil}, + }, nil) + + wg.Add(2) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: updateStoreRequest, + Payload: map[string]interface{}{ + "storeId": "test-store", + "itemId": "item3", + "newItemValue": nil}, + }, nil) + wg.Wait() + + assert.Equal(t, len(syncResp1), 5) + assert.Equal(t, len(syncResp2), 5) + + assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).ResponseType, "updateStoreResponse") + assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).StoreId, "test-store") + assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).ItemId, "item3") + assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).StoreVersion, int64(5)) + assert.Equal(t, syncResp1[4].(*model.UpdateStoreResponse).NewItemValue, nil) + + assert.Equal(t, syncResp1[4], syncResp2[4]) + + assert.Equal(t, store.GetValue("item3"), nil) + + wg.Add(1) + bus.SendRequestMessage(syncChan, &model.Request{ + Request: updateStoreRequest, + Payload: map[string]interface{}{ + "storeId": "test-store", + "itemId": "item3", + "newItemValue": "test"}, + }, nil) + wg.Wait() + assert.Equal(t, len(syncResp1), 6) + assert.True(t, strings.HasPrefix(syncResp1[5].(*model.Response).ErrorMessage, + "Cannot deserialize UpdateStoreRequest item value:")) } diff --git a/bus/transaction.go b/bus/transaction.go index 06f85df..f363986 100644 --- a/bus/transaction.go +++ b/bus/transaction.go @@ -4,263 +4,265 @@ package bus import ( - "fmt" - "github.com/google/uuid" - "github.com/vmware/transport-go/model" - "sync" + "fmt" + "github.com/google/uuid" + "github.com/vmware/transport-go/model" + "sync" ) type transactionType int + const ( - asyncTransaction transactionType = iota - syncTransaction + asyncTransaction transactionType = iota + syncTransaction ) type BusTransactionReadyFunction func(responses []*model.Message) type BusTransaction interface { - // Sends a request to a channel as a part of this transaction. - SendRequest(channel string, payload interface{}) error - // Wait for a store to be initialized as a part of this transaction. - WaitForStoreReady(storeName string) error - // Registers a new complete handler. Once all responses to requests have been received, - // the transaction is complete. - OnComplete(completeHandler BusTransactionReadyFunction) error - // Register a new error handler. If an error is thrown by any of the responders, the transaction - // is aborted and the error sent to the registered errorHandlers. - OnError(errorHandler MessageErrorFunction) error - // Commit the transaction, all requests will be sent and will wait for responses. - // Once all the responses are in, onComplete handlers will be called with the responses. - Commit() error + // Sends a request to a channel as a part of this transaction. + SendRequest(channel string, payload interface{}) error + // Wait for a store to be initialized as a part of this transaction. + WaitForStoreReady(storeName string) error + // Registers a new complete handler. Once all responses to requests have been received, + // the transaction is complete. + OnComplete(completeHandler BusTransactionReadyFunction) error + // Register a new error handler. If an error is thrown by any of the responders, the transaction + // is aborted and the error sent to the registered errorHandlers. + OnError(errorHandler MessageErrorFunction) error + // Commit the transaction, all requests will be sent and will wait for responses. + // Once all the responses are in, onComplete handlers will be called with the responses. + Commit() error } type transactionState int + const ( - uncommittedState transactionState = iota - committedState - completedState - abortedState + uncommittedState transactionState = iota + committedState + completedState + abortedState ) type busTransactionRequest struct { - requestIndex int - storeName string - channelName string - payload interface{} + requestIndex int + storeName string + channelName string + payload interface{} } type busTransaction struct { - transactionType transactionType - state transactionState - lock sync.Mutex - requests []*busTransactionRequest - responses []*model.Message - onCompleteHandlers []BusTransactionReadyFunction - onErrorHandlers []MessageErrorFunction - bus EventBus - completedRequests int + transactionType transactionType + state transactionState + lock sync.Mutex + requests []*busTransactionRequest + responses []*model.Message + onCompleteHandlers []BusTransactionReadyFunction + onErrorHandlers []MessageErrorFunction + bus EventBus + completedRequests int } func newBusTransaction(bus EventBus, transactionType transactionType) BusTransaction { - transaction := new(busTransaction) + transaction := new(busTransaction) - transaction.bus = bus - transaction.state = uncommittedState - transaction.transactionType = transactionType - transaction.requests = make([]*busTransactionRequest, 0) - transaction.onCompleteHandlers = make([]BusTransactionReadyFunction, 0) - transaction.onErrorHandlers = make([]MessageErrorFunction, 0) - transaction.completedRequests = 0 + transaction.bus = bus + transaction.state = uncommittedState + transaction.transactionType = transactionType + transaction.requests = make([]*busTransactionRequest, 0) + transaction.onCompleteHandlers = make([]BusTransactionReadyFunction, 0) + transaction.onErrorHandlers = make([]MessageErrorFunction, 0) + transaction.completedRequests = 0 - return transaction + return transaction } func (tr *busTransaction) checkUncommittedState() error { - if tr.state != uncommittedState { - return fmt.Errorf("transaction has already been committed") - } - return nil + if tr.state != uncommittedState { + return fmt.Errorf("transaction has already been committed") + } + return nil } func (tr *busTransaction) SendRequest(channel string, payload interface{}) error { - tr.lock.Lock() - defer tr.lock.Unlock() + tr.lock.Lock() + defer tr.lock.Unlock() - if err := tr.checkUncommittedState(); err != nil { - return err - } + if err := tr.checkUncommittedState(); err != nil { + return err + } - tr.requests = append(tr.requests, &busTransactionRequest{ - channelName: channel, - payload: payload, - requestIndex: len(tr.requests), - }) + tr.requests = append(tr.requests, &busTransactionRequest{ + channelName: channel, + payload: payload, + requestIndex: len(tr.requests), + }) - return nil + return nil } func (tr *busTransaction) WaitForStoreReady(storeName string) error { - tr.lock.Lock() - defer tr.lock.Unlock() + tr.lock.Lock() + defer tr.lock.Unlock() - if err := tr.checkUncommittedState(); err != nil { - return err - } + if err := tr.checkUncommittedState(); err != nil { + return err + } - if tr.bus.GetStoreManager().GetStore(storeName) == nil { - return fmt.Errorf("cannot find store '%s'", storeName) - } + if tr.bus.GetStoreManager().GetStore(storeName) == nil { + return fmt.Errorf("cannot find store '%s'", storeName) + } - tr.requests = append(tr.requests, &busTransactionRequest{ - storeName: storeName, - requestIndex: len(tr.requests), - }) + tr.requests = append(tr.requests, &busTransactionRequest{ + storeName: storeName, + requestIndex: len(tr.requests), + }) - return nil + return nil } func (tr *busTransaction) OnComplete(completeHandler BusTransactionReadyFunction) error { - tr.lock.Lock() - defer tr.lock.Unlock() + tr.lock.Lock() + defer tr.lock.Unlock() - if err := tr.checkUncommittedState(); err != nil { - return err - } + if err := tr.checkUncommittedState(); err != nil { + return err + } - tr.onCompleteHandlers = append(tr.onCompleteHandlers, completeHandler) - return nil + tr.onCompleteHandlers = append(tr.onCompleteHandlers, completeHandler) + return nil } func (tr *busTransaction) OnError(errorHandler MessageErrorFunction) error { - tr.lock.Lock() - defer tr.lock.Unlock() + tr.lock.Lock() + defer tr.lock.Unlock() - if err := tr.checkUncommittedState(); err != nil { - return err - } + if err := tr.checkUncommittedState(); err != nil { + return err + } - tr.onErrorHandlers = append(tr.onErrorHandlers, errorHandler) - return nil + tr.onErrorHandlers = append(tr.onErrorHandlers, errorHandler) + return nil } func (tr *busTransaction) Commit() error { - tr.lock.Lock() - defer tr.lock.Unlock() + tr.lock.Lock() + defer tr.lock.Unlock() - if err := tr.checkUncommittedState(); err != nil { - return err - } + if err := tr.checkUncommittedState(); err != nil { + return err + } - if len(tr.requests) == 0 { - return fmt.Errorf("cannot commit empty transaction") - } + if len(tr.requests) == 0 { + return fmt.Errorf("cannot commit empty transaction") + } - tr.state = committedState + tr.state = committedState - // init responses slice - tr.responses = make([]*model.Message, len(tr.requests)) + // init responses slice + tr.responses = make([]*model.Message, len(tr.requests)) - if tr.transactionType == asyncTransaction { - tr.startAsyncTransaction() - } else { - tr.startSyncTransaction() - } + if tr.transactionType == asyncTransaction { + tr.startAsyncTransaction() + } else { + tr.startSyncTransaction() + } - return nil + return nil } func (tr *busTransaction) startSyncTransaction() { - tr.executeRequest(tr.requests[0]) + tr.executeRequest(tr.requests[0]) } func (tr *busTransaction) executeRequest(request *busTransactionRequest) { - if request.storeName != "" { - tr.waitForStore(request) - } else { - tr.sendRequest(request) - } + if request.storeName != "" { + tr.waitForStore(request) + } else { + tr.sendRequest(request) + } } func (tr *busTransaction) startAsyncTransaction() { - for _, req := range tr.requests { - tr.executeRequest(req) - } + for _, req := range tr.requests { + tr.executeRequest(req) + } } func (tr *busTransaction) sendRequest(req *busTransactionRequest) { - reqId := uuid.New() + reqId := uuid.New() - mh, err := tr.bus.ListenOnceForDestination(req.channelName, &reqId) - if err != nil { - tr.onTransactionError(err) - return - } + mh, err := tr.bus.ListenOnceForDestination(req.channelName, &reqId) + if err != nil { + tr.onTransactionError(err) + return + } - mh.Handle(func(message *model.Message) { - tr.onTransactionRequestSuccess(req, message) - }, func(e error) { - tr.onTransactionError(e) - }) + mh.Handle(func(message *model.Message) { + tr.onTransactionRequestSuccess(req, message) + }, func(e error) { + tr.onTransactionError(e) + }) - tr.bus.SendRequestMessage(req.channelName, req.payload, &reqId) + tr.bus.SendRequestMessage(req.channelName, req.payload, &reqId) } func (tr *busTransaction) onTransactionError(err error) { - tr.lock.Lock() - defer tr.lock.Unlock() + tr.lock.Lock() + defer tr.lock.Unlock() - if tr.state == abortedState { - return - } + if tr.state == abortedState { + return + } - tr.state = abortedState - for _, errorHandler := range tr.onErrorHandlers { - go errorHandler(err) - } + tr.state = abortedState + for _, errorHandler := range tr.onErrorHandlers { + go errorHandler(err) + } } func (tr *busTransaction) waitForStore(req *busTransactionRequest) { - store := tr.bus.GetStoreManager().GetStore(req.storeName) - if store == nil { - tr.onTransactionError(fmt.Errorf("cannot find store '%s'", req.storeName)) - return - } - store.WhenReady(func() { - tr.onTransactionRequestSuccess(req, &model.Message{ - Direction: model.ResponseDir, - Payload: store.AllValuesAsMap(), - }) - }) + store := tr.bus.GetStoreManager().GetStore(req.storeName) + if store == nil { + tr.onTransactionError(fmt.Errorf("cannot find store '%s'", req.storeName)) + return + } + store.WhenReady(func() { + tr.onTransactionRequestSuccess(req, &model.Message{ + Direction: model.ResponseDir, + Payload: store.AllValuesAsMap(), + }) + }) } func (tr *busTransaction) onTransactionRequestSuccess(req *busTransactionRequest, message *model.Message) { - var triggerOnCompleteHandler = false - tr.lock.Lock() - - if tr.state == abortedState { - tr.lock.Unlock() - return - } - - tr.responses[req.requestIndex] = message - tr.completedRequests++ - - if tr.completedRequests == len(tr.requests) { - tr.state = completedState - triggerOnCompleteHandler = true - } - - tr.lock.Unlock() - - if triggerOnCompleteHandler { - for _, completeHandler := range tr.onCompleteHandlers { - go completeHandler(tr.responses) - } - return - } - - // If this is a sync transaction execute the next request - if tr.transactionType == syncTransaction && req.requestIndex < len(tr.requests) - 1 { - tr.executeRequest(tr.requests[req.requestIndex + 1]) - } + var triggerOnCompleteHandler = false + tr.lock.Lock() + + if tr.state == abortedState { + tr.lock.Unlock() + return + } + + tr.responses[req.requestIndex] = message + tr.completedRequests++ + + if tr.completedRequests == len(tr.requests) { + tr.state = completedState + triggerOnCompleteHandler = true + } + + tr.lock.Unlock() + + if triggerOnCompleteHandler { + for _, completeHandler := range tr.onCompleteHandlers { + go completeHandler(tr.responses) + } + return + } + + // If this is a sync transaction execute the next request + if tr.transactionType == syncTransaction && req.requestIndex < len(tr.requests)-1 { + tr.executeRequest(tr.requests[req.requestIndex+1]) + } } diff --git a/bus/transaction_test.go b/bus/transaction_test.go index 838c753..91f3a01 100644 --- a/bus/transaction_test.go +++ b/bus/transaction_test.go @@ -4,322 +4,319 @@ package bus import ( - "errors" - "github.com/stretchr/testify/assert" - "github.com/vmware/transport-go/model" - "sync" - "sync/atomic" - "testing" + "errors" + "github.com/stretchr/testify/assert" + "github.com/vmware/transport-go/model" + "sync" + "sync/atomic" + "testing" ) func TestBusTransaction_OnCompleteSync(t *testing.T) { - bus := newTestEventBus() + bus := newTestEventBus() - bus.GetChannelManager().CreateChannel("test-channel") + bus.GetChannelManager().CreateChannel("test-channel") - var channelReqMessage *model.Message - var requestCounter = 0 + var channelReqMessage *model.Message + var requestCounter = 0 - wg := sync.WaitGroup{} + wg := sync.WaitGroup{} - mh,_ := bus.ListenRequestStream("test-channel") - mh.Handle(func(message *model.Message) { - requestCounter++ - channelReqMessage = message - wg.Done() - }, func(e error) { - assert.Fail(t, "unexpected error") - }) + mh, _ := bus.ListenRequestStream("test-channel") + mh.Handle(func(message *model.Message) { + requestCounter++ + channelReqMessage = message + wg.Done() + }, func(e error) { + assert.Fail(t, "unexpected error") + }) - tr := newBusTransaction(bus, syncTransaction) + tr := newBusTransaction(bus, syncTransaction) - bus.GetStoreManager().CreateStore("testStore") - assert.Nil(t, tr.WaitForStoreReady("testStore")) - assert.Nil(t, tr.SendRequest("test-channel", "sample-request")) + bus.GetStoreManager().CreateStore("testStore") + assert.Nil(t, tr.WaitForStoreReady("testStore")) + assert.Nil(t, tr.SendRequest("test-channel", "sample-request")) - var completeCounter int64 + var completeCounter int64 - tr.OnComplete(func(responses []*model.Message) { - atomic.AddInt64(&completeCounter, 1) - wg.Done() - }) + tr.OnComplete(func(responses []*model.Message) { + atomic.AddInt64(&completeCounter, 1) + wg.Done() + }) - tr.OnError(func(e error) { - assert.Fail(t, "unexpected error") - }) + tr.OnError(func(e error) { + assert.Fail(t, "unexpected error") + }) - tr.OnComplete(func(responses []*model.Message) { - atomic.AddInt64(&completeCounter, 1) - assert.Equal(t, len(responses), 2) - assert.Equal(t, responses[1].Channel, "test-channel") - assert.Equal(t, responses[1].Payload, "sample-response") - wg.Done() - }) + tr.OnComplete(func(responses []*model.Message) { + atomic.AddInt64(&completeCounter, 1) + assert.Equal(t, len(responses), 2) + assert.Equal(t, responses[1].Channel, "test-channel") + assert.Equal(t, responses[1].Payload, "sample-response") + wg.Done() + }) - assert.Equal(t, requestCounter, 0) + assert.Equal(t, requestCounter, 0) - wg.Add(1) + wg.Add(1) - assert.Nil(t, tr.Commit()) + assert.Nil(t, tr.Commit()) - go bus.GetStoreManager().CreateStore("testStore").Initialize() + go bus.GetStoreManager().CreateStore("testStore").Initialize() - wg.Wait() + wg.Wait() - assert.Equal(t, requestCounter, 1) - assert.NotNil(t, channelReqMessage) + assert.Equal(t, requestCounter, 1) + assert.NotNil(t, channelReqMessage) - assert.Equal(t, channelReqMessage.Payload, "sample-request") + assert.Equal(t, channelReqMessage.Payload, "sample-request") - for i := 0; i < 50; i++ { - bus.SendResponseMessage("test-channel", "general-message", nil) - } + for i := 0; i < 50; i++ { + bus.SendResponseMessage("test-channel", "general-message", nil) + } - assert.Equal(t, completeCounter, int64(0)) + assert.Equal(t, completeCounter, int64(0)) - wg.Add(2) - bus.SendResponseMessage("test-channel", "sample-response", channelReqMessage.DestinationId) + wg.Add(2) + bus.SendResponseMessage("test-channel", "sample-response", channelReqMessage.DestinationId) - wg.Wait() + wg.Wait() - assert.Equal(t, tr.(*busTransaction).state, completedState) + assert.Equal(t, tr.(*busTransaction).state, completedState) - assert.Equal(t, completeCounter, int64(2)) + assert.Equal(t, completeCounter, int64(2)) - bus.SendResponseMessage("test-channel", "sample-response2", channelReqMessage.DestinationId) - assert.Equal(t, completeCounter, int64(2)) + bus.SendResponseMessage("test-channel", "sample-response2", channelReqMessage.DestinationId) + assert.Equal(t, completeCounter, int64(2)) } func TestBusTransaction_OnCompleteErrorHandling(t *testing.T) { - bus := newTestEventBus() + bus := newTestEventBus() - tr := newBusTransaction(bus, syncTransaction) + tr := newBusTransaction(bus, syncTransaction) - assert.EqualError(t, tr.Commit(), "cannot commit empty transaction") + assert.EqualError(t, tr.Commit(), "cannot commit empty transaction") - assert.Equal(t, tr.(*busTransaction).state, uncommittedState) + assert.Equal(t, tr.(*busTransaction).state, uncommittedState) - bus.GetStoreManager().CreateStore("testStore") - assert.Nil(t, tr.WaitForStoreReady("testStore")) + bus.GetStoreManager().CreateStore("testStore") + assert.Nil(t, tr.WaitForStoreReady("testStore")) - assert.EqualError(t, tr.WaitForStoreReady("invalid-store"), "cannot find store 'invalid-store'") + assert.EqualError(t, tr.WaitForStoreReady("invalid-store"), "cannot find store 'invalid-store'") - tr.Commit() + tr.Commit() - assert.EqualError(t, tr.OnComplete(func(responses []*model.Message) {}), "transaction has already been committed") + assert.EqualError(t, tr.OnComplete(func(responses []*model.Message) {}), "transaction has already been committed") - assert.Equal(t, tr.(*busTransaction).state, committedState) - assert.EqualError(t, tr.Commit(), "transaction has already been committed") + assert.Equal(t, tr.(*busTransaction).state, committedState) + assert.EqualError(t, tr.Commit(), "transaction has already been committed") - assert.EqualError(t, tr.WaitForStoreReady("test"), "transaction has already been committed") - assert.EqualError(t, tr.SendRequest("test", "test"), "transaction has already been committed") + assert.EqualError(t, tr.WaitForStoreReady("test"), "transaction has already been committed") + assert.EqualError(t, tr.SendRequest("test", "test"), "transaction has already been committed") } func TestBusTransaction_OnErrorSync(t *testing.T) { - bus := newTestEventBus() + bus := newTestEventBus() - tr := newBusTransaction(bus, syncTransaction) + tr := newBusTransaction(bus, syncTransaction) - bus.GetStoreManager().CreateStore("testStore") - assert.Nil(t, tr.WaitForStoreReady("testStore")) + bus.GetStoreManager().CreateStore("testStore") + assert.Nil(t, tr.WaitForStoreReady("testStore")) - bus.GetChannelManager().CreateChannel("test-channel") + bus.GetChannelManager().CreateChannel("test-channel") - var channelReqMessage *model.Message - var requestCounter = 0 + var channelReqMessage *model.Message + var requestCounter = 0 - wg := sync.WaitGroup{} + wg := sync.WaitGroup{} - mh,_ := bus.ListenRequestStream("test-channel") - mh.Handle(func(message *model.Message) { - requestCounter++ - channelReqMessage = message - wg.Done() - }, func(e error) { - }) + mh, _ := bus.ListenRequestStream("test-channel") + mh.Handle(func(message *model.Message) { + requestCounter++ + channelReqMessage = message + wg.Done() + }, func(e error) { + }) - tr.SendRequest("test-channel", "sample-request") - tr.SendRequest("test-channel", "sample-request") - tr.SendRequest("test-channel", "sample-request") + tr.SendRequest("test-channel", "sample-request") + tr.SendRequest("test-channel", "sample-request") + tr.SendRequest("test-channel", "sample-request") + tr.OnComplete(func(responses []*model.Message) { + assert.Fail(t, "invalid state") + }) - tr.OnComplete(func(responses []*model.Message) { - assert.Fail(t, "invalid state") - }) + var errorHandlerCount int64 = 0 + tr.OnError(func(e error) { + atomic.AddInt64(&errorHandlerCount, 1) + wg.Done() + }) - var errorHandlerCount int64 = 0 - tr.OnError(func(e error) { - atomic.AddInt64(&errorHandlerCount, 1) - wg.Done() - }) + tr.OnError(func(e error) { + atomic.AddInt64(&errorHandlerCount, 1) + assert.EqualError(t, e, "test-error") + wg.Done() + }) - tr.OnError(func(e error) { - atomic.AddInt64(&errorHandlerCount, 1) - assert.EqualError(t, e, "test-error") - wg.Done() - }) + tr.Commit() - tr.Commit() + assert.Equal(t, tr.(*busTransaction).state, committedState) - assert.Equal(t, tr.(*busTransaction).state, committedState) + wg.Add(1) - wg.Add(1) + bus.GetStoreManager().GetStore("testStore").Initialize() - bus.GetStoreManager().GetStore("testStore").Initialize() + wg.Wait() - wg.Wait() + assert.Equal(t, requestCounter, 1) + assert.NotNil(t, channelReqMessage) - assert.Equal(t, requestCounter, 1) - assert.NotNil(t, channelReqMessage) + wg.Add(2) + bus.SendErrorMessage("test-channel", errors.New("test-error"), channelReqMessage.DestinationId) - wg.Add(2) - bus.SendErrorMessage("test-channel", errors.New("test-error"), channelReqMessage.DestinationId) + wg.Wait() - wg.Wait() + assert.Equal(t, tr.(*busTransaction).state, abortedState) - assert.Equal(t, tr.(*busTransaction).state, abortedState) + assert.Equal(t, requestCounter, 1) + assert.Equal(t, errorHandlerCount, int64(2)) - assert.Equal(t, requestCounter, 1) - assert.Equal(t, errorHandlerCount, int64(2)) - - assert.EqualError(t, tr.Commit(), "transaction has already been committed") + assert.EqualError(t, tr.Commit(), "transaction has already been committed") } func TestBusTransaction_OnCompleteAsync(t *testing.T) { - bus := newTestEventBus() + bus := newTestEventBus() - bus.GetChannelManager().CreateChannel("test-channel") + bus.GetChannelManager().CreateChannel("test-channel") - var channelReqMessage *model.Message - var requestCounter = 0 + var channelReqMessage *model.Message + var requestCounter = 0 - wg := sync.WaitGroup{} + wg := sync.WaitGroup{} - mh,_ := bus.ListenRequestStream("test-channel") - mh.Handle(func(message *model.Message) { - requestCounter++ - channelReqMessage = message - wg.Done() - }, func(e error) { - assert.Fail(t, "unexpected error") - }) + mh, _ := bus.ListenRequestStream("test-channel") + mh.Handle(func(message *model.Message) { + requestCounter++ + channelReqMessage = message + wg.Done() + }, func(e error) { + assert.Fail(t, "unexpected error") + }) - tr := newBusTransaction(bus, asyncTransaction) + tr := newBusTransaction(bus, asyncTransaction) - bus.GetStoreManager().CreateStore("testStore") - assert.Nil(t, tr.WaitForStoreReady("testStore")) - assert.Nil(t, tr.WaitForStoreReady("testStore")) - bus.GetStoreManager().CreateStore("testStore2") - assert.Nil(t, tr.WaitForStoreReady("testStore2")) - bus.GetStoreManager().CreateStore("testStore3") - assert.Nil(t, tr.WaitForStoreReady("testStore3")) - assert.Nil(t, tr.SendRequest("test-channel", "sample-request")) + bus.GetStoreManager().CreateStore("testStore") + assert.Nil(t, tr.WaitForStoreReady("testStore")) + assert.Nil(t, tr.WaitForStoreReady("testStore")) + bus.GetStoreManager().CreateStore("testStore2") + assert.Nil(t, tr.WaitForStoreReady("testStore2")) + bus.GetStoreManager().CreateStore("testStore3") + assert.Nil(t, tr.WaitForStoreReady("testStore3")) + assert.Nil(t, tr.SendRequest("test-channel", "sample-request")) - var completeCounter int64 + var completeCounter int64 - tr.OnComplete(func(responses []*model.Message) { - atomic.AddInt64(&completeCounter, 1) - wg.Done() - }) + tr.OnComplete(func(responses []*model.Message) { + atomic.AddInt64(&completeCounter, 1) + wg.Done() + }) - tr.OnComplete(func(responses []*model.Message) { - atomic.AddInt64(&completeCounter, 1) - assert.Equal(t, len(responses), 5) - assert.Equal(t, responses[4].Channel, "test-channel") - assert.Equal(t, responses[4].Payload, "sample-response") - wg.Done() - }) + tr.OnComplete(func(responses []*model.Message) { + atomic.AddInt64(&completeCounter, 1) + assert.Equal(t, len(responses), 5) + assert.Equal(t, responses[4].Channel, "test-channel") + assert.Equal(t, responses[4].Payload, "sample-response") + wg.Done() + }) - wg.Add(1) - assert.Nil(t, tr.Commit()) - wg.Wait() + wg.Add(1) + assert.Nil(t, tr.Commit()) + wg.Wait() - assert.NotNil(t, bus.GetStoreManager().GetStore("testStore")) - assert.NotNil(t, bus.GetStoreManager().GetStore("testStore2")) - assert.NotNil(t, bus.GetStoreManager().GetStore("testStore3")) - assert.Equal(t, requestCounter, 1) - assert.NotNil(t, channelReqMessage) - assert.Equal(t, channelReqMessage.Payload, "sample-request") + assert.NotNil(t, bus.GetStoreManager().GetStore("testStore")) + assert.NotNil(t, bus.GetStoreManager().GetStore("testStore2")) + assert.NotNil(t, bus.GetStoreManager().GetStore("testStore3")) + assert.Equal(t, requestCounter, 1) + assert.NotNil(t, channelReqMessage) + assert.Equal(t, channelReqMessage.Payload, "sample-request") - for i := 0; i < 20; i++ { - bus.SendResponseMessage("test-channel", "general-message", nil) - } + for i := 0; i < 20; i++ { + bus.SendResponseMessage("test-channel", "general-message", nil) + } - assert.Equal(t, completeCounter, int64(0)) + assert.Equal(t, completeCounter, int64(0)) - wg.Add(2) + wg.Add(2) - bus.SendResponseMessage("test-channel", "sample-response", channelReqMessage.DestinationId) - bus.GetStoreManager().GetStore("testStore").Initialize() - bus.GetStoreManager().GetStore("testStore2").Initialize() - bus.GetStoreManager().GetStore("testStore3").Initialize() + bus.SendResponseMessage("test-channel", "sample-response", channelReqMessage.DestinationId) + bus.GetStoreManager().GetStore("testStore").Initialize() + bus.GetStoreManager().GetStore("testStore2").Initialize() + bus.GetStoreManager().GetStore("testStore3").Initialize() - wg.Wait() + wg.Wait() - assert.Equal(t, completeCounter, int64(2)) + assert.Equal(t, completeCounter, int64(2)) } func TestBusTransaction_OnErrorAsync(t *testing.T) { - bus := newTestEventBus() - - tr := newBusTransaction(bus, asyncTransaction) - - bus.GetChannelManager().CreateChannel("test-channel") - bus.GetChannelManager().CreateChannel("test-channel2") + bus := newTestEventBus() - var channelReqMessage, channelReqMessage2 *model.Message + tr := newBusTransaction(bus, asyncTransaction) + bus.GetChannelManager().CreateChannel("test-channel") + bus.GetChannelManager().CreateChannel("test-channel2") - wg := sync.WaitGroup{} + var channelReqMessage, channelReqMessage2 *model.Message - mh,_ := bus.ListenRequestStream("test-channel") - mh.Handle(func(message *model.Message) { - channelReqMessage = message - wg.Done() - }, func(e error) { - }) + wg := sync.WaitGroup{} - mh2,_ := bus.ListenRequestStream("test-channel2") - mh2.Handle(func(message *model.Message) { - channelReqMessage2 = message - wg.Done() - }, func(e error) { - }) + mh, _ := bus.ListenRequestStream("test-channel") + mh.Handle(func(message *model.Message) { + channelReqMessage = message + wg.Done() + }, func(e error) { + }) + mh2, _ := bus.ListenRequestStream("test-channel2") + mh2.Handle(func(message *model.Message) { + channelReqMessage2 = message + wg.Done() + }, func(e error) { + }) - tr.OnComplete(func(responses []*model.Message) { - assert.Fail(t, "invalid state") - }) + tr.OnComplete(func(responses []*model.Message) { + assert.Fail(t, "invalid state") + }) - var errorHandlerCount int64 = 0 - tr.OnError(func(e error) { - atomic.AddInt64(&errorHandlerCount, 1) - assert.EqualError(t, e, "test-error") - wg.Done() - }) + var errorHandlerCount int64 = 0 + tr.OnError(func(e error) { + atomic.AddInt64(&errorHandlerCount, 1) + assert.EqualError(t, e, "test-error") + wg.Done() + }) - tr.SendRequest("test-channel", "sample-request") - tr.SendRequest("test-channel2", "sample-request2") + tr.SendRequest("test-channel", "sample-request") + tr.SendRequest("test-channel2", "sample-request2") - wg.Add(2) - tr.Commit() - wg.Wait() + wg.Add(2) + tr.Commit() + wg.Wait() - wg.Add(1) - bus.SendErrorMessage("test-channel2", errors.New("test-error"), channelReqMessage2.DestinationId) + wg.Add(1) + bus.SendErrorMessage("test-channel2", errors.New("test-error"), channelReqMessage2.DestinationId) - wg.Wait() + wg.Wait() - assert.Equal(t, errorHandlerCount, int64(1)) + assert.Equal(t, errorHandlerCount, int64(1)) - for i := 0; i < 50; i++ { - bus.SendErrorMessage("test-channel", errors.New("test-error-2"), channelReqMessage.DestinationId) - } + for i := 0; i < 50; i++ { + bus.SendErrorMessage("test-channel", errors.New("test-error-2"), channelReqMessage.DestinationId) + } - assert.Equal(t, errorHandlerCount, int64(1)) + assert.Equal(t, errorHandlerCount, int64(1)) } diff --git a/log/logger.go b/log/logger.go index ff74533..9b326ec 100644 --- a/log/logger.go +++ b/log/logger.go @@ -4,10 +4,10 @@ package log import ( - "fmt" - "github.com/fatih/color" - "os" - "strings" + "fmt" + "github.com/fatih/color" + "os" + "strings" ) // These flags have to be set from the "opts" module in each sewing-machine tool @@ -21,60 +21,60 @@ var Version = "" // Print warnings func Warn(format string, arg ...interface{}) { - color.NoColor = false - color.Set(color.FgHiMagenta) - if !WarnFlag { - fmt.Printf("⚠️🚨 WARNING: "+format, arg...) - } - color.Unset() + color.NoColor = false + color.Set(color.FgHiMagenta) + if !WarnFlag { + fmt.Printf("⚠️🚨 WARNING: "+format, arg...) + } + color.Unset() } // Print traces func Trace(format string, arg ...interface{}) { - color.NoColor = false - color.Set(color.FgCyan) - color.Set(color.Faint) - if TraceFlag { - fmt.Printf(format, arg...) - } - color.Unset() + color.NoColor = false + color.Set(color.FgCyan) + color.Set(color.Faint) + if TraceFlag { + fmt.Printf(format, arg...) + } + color.Unset() } // Print debug func Debug(format string, arg ...interface{}) { - if DebugFlag { - fmt.Printf(format, arg...) - } + if DebugFlag { + fmt.Printf(format, arg...) + } } // Print verbose func Verbose(format string, arg ...interface{}) { - color.NoColor = false - color.Set(color.FgHiMagenta) - if VerboseFlag { - fmt.Printf(format, arg...) - } - color.Unset() + color.NoColor = false + color.Set(color.FgHiMagenta) + if VerboseFlag { + fmt.Printf(format, arg...) + } + color.Unset() } // Catchable Panic func Panicf(format string, args ...interface{}) { - color.NoColor = false - color.Set(color.FgRed) - color.Set(color.Bold) + color.NoColor = false + color.Set(color.FgRed) + color.Set(color.Bold) - fmt.Printf("❌ FATAL: "+format, args...) - color.Unset() - if !RecoverOnError { - os.Exit(4) - } + fmt.Printf("❌ FATAL: "+format, args...) + color.Unset() + if !RecoverOnError { + os.Exit(4) + } } func SetVersion(version string) { - Version = version - if strings.Contains(Version, "-") { - Version = Version[:strings.Index(version, "-")] - } else { - Version = "v2.285" - } + Version = version + if strings.Contains(Version, "-") { + Version = Version[:strings.Index(version, "-")] + } else { + Version = "v2.285" + } } diff --git a/model/store_responses.go b/model/store_responses.go index f426cf4..fbb5a44 100644 --- a/model/store_responses.go +++ b/model/store_responses.go @@ -4,39 +4,39 @@ package model type StoreContentResponse struct { - Items map[string]interface{} `json:"items"` - ResponseType string `json:"responseType"` // should be "storeContentResponse" - StoreId string `json:"storeId"` - StoreVersion int64 `json:"storeVersion"` + Items map[string]interface{} `json:"items"` + ResponseType string `json:"responseType"` // should be "storeContentResponse" + StoreId string `json:"storeId"` + StoreVersion int64 `json:"storeVersion"` } func NewStoreContentResponse( - storeId string, items map[string]interface{}, storeVersion int64) *StoreContentResponse { + storeId string, items map[string]interface{}, storeVersion int64) *StoreContentResponse { - return &StoreContentResponse{ - ResponseType: "storeContentResponse", - StoreId: storeId, - Items: items, - StoreVersion: storeVersion, - } + return &StoreContentResponse{ + ResponseType: "storeContentResponse", + StoreId: storeId, + Items: items, + StoreVersion: storeVersion, + } } type UpdateStoreResponse struct { - ItemId string `json:"itemId"` - NewItemValue interface{} `json:"newItemValue"` - ResponseType string `json:"responseType"` // should be "updateStoreResponse" - StoreId string `json:"storeId"` - StoreVersion int64 `json:"storeVersion"` + ItemId string `json:"itemId"` + NewItemValue interface{} `json:"newItemValue"` + ResponseType string `json:"responseType"` // should be "updateStoreResponse" + StoreId string `json:"storeId"` + StoreVersion int64 `json:"storeVersion"` } func NewUpdateStoreResponse( - storeId string, itemId string, newValue interface{},storeVersion int64) *UpdateStoreResponse { + storeId string, itemId string, newValue interface{}, storeVersion int64) *UpdateStoreResponse { - return &UpdateStoreResponse{ - ResponseType: "updateStoreResponse", - StoreId: storeId, - StoreVersion: storeVersion, - ItemId: itemId, - NewItemValue: newValue, - } -} \ No newline at end of file + return &UpdateStoreResponse{ + ResponseType: "updateStoreResponse", + StoreId: storeId, + StoreVersion: storeVersion, + ItemId: itemId, + NewItemValue: newValue, + } +} diff --git a/model/util.go b/model/util.go index c645a47..e2509b7 100644 --- a/model/util.go +++ b/model/util.go @@ -4,35 +4,35 @@ package model import ( - "encoding/json" - "reflect" + "encoding/json" + "reflect" ) func ConvertValueToType(value interface{}, targetType reflect.Type) (interface{}, error) { - if targetType == nil { - return value, nil - } + if targetType == nil { + return value, nil + } - itemType := targetType - var isTargetTypePointer bool + itemType := targetType + var isTargetTypePointer bool - if itemType.Kind() == reflect.Ptr { - isTargetTypePointer = true - itemType = itemType.Elem() - } + if itemType.Kind() == reflect.Ptr { + isTargetTypePointer = true + itemType = itemType.Elem() + } - decodedValuePtr := reflect.New(itemType).Interface() + decodedValuePtr := reflect.New(itemType).Interface() - marshaledValue, _ := json.Marshal(value) - decodeErr := json.Unmarshal(marshaledValue, decodedValuePtr) + marshaledValue, _ := json.Marshal(value) + decodeErr := json.Unmarshal(marshaledValue, decodedValuePtr) - if decodeErr != nil { - return nil, decodeErr - } + if decodeErr != nil { + return nil, decodeErr + } - if isTargetTypePointer { - return decodedValuePtr, nil - } else { - return reflect.ValueOf(decodedValuePtr).Elem().Interface(), nil - } + if isTargetTypePointer { + return decodedValuePtr, nil + } else { + return reflect.ValueOf(decodedValuePtr).Elem().Interface(), nil + } } diff --git a/plank/pkg/metrics/pageview_metric.go b/plank/pkg/metrics/pageview_metric.go index cc3dc85..cb1dce5 100644 --- a/plank/pkg/metrics/pageview_metric.go +++ b/plank/pkg/metrics/pageview_metric.go @@ -1,8 +1,8 @@ // Copyright 2019-2021 VMware, Inc. // SPDX-License-Identifier: BSD-2-Clause -// +build !js -// +build !wasm +//go:build !js && !wasm +// +build !js,!wasm package metrics diff --git a/plank/pkg/middleware/cache_control.go b/plank/pkg/middleware/cache_control.go index 51c84de..0e53e92 100644 --- a/plank/pkg/middleware/cache_control.go +++ b/plank/pkg/middleware/cache_control.go @@ -17,8 +17,8 @@ import ( // for the matching pattern and the compiled glob pattern for use in runtime. see https://github.com/gobwas/glob for // detailed examples of glob patterns. type CacheControlRulePair struct { - GlobPattern string - CacheControlRule string + GlobPattern string + CacheControlRule string CompiledGlobPattern glob.Glob } @@ -27,8 +27,8 @@ type CacheControlRulePair struct { func NewCacheControlRulePair(globPattern string, cacheControlRule string) (CacheControlRulePair, error) { var err error pair := CacheControlRulePair{ - GlobPattern: globPattern, - CacheControlRule: cacheControlRule, + GlobPattern: globPattern, + CacheControlRule: cacheControlRule, } pair.CompiledGlobPattern, err = glob.Compile(globPattern) diff --git a/plank/pkg/middleware/prometheus_metrics.go b/plank/pkg/middleware/prometheus_metrics.go index 9fbbcee..0161246 100644 --- a/plank/pkg/middleware/prometheus_metrics.go +++ b/plank/pkg/middleware/prometheus_metrics.go @@ -1,8 +1,8 @@ // Copyright 2019-2021 VMware, Inc. // SPDX-License-Identifier: BSD-2-Clause -// +build !js -// +build !wasm +//go:build !js && !wasm +// +build !js,!wasm package middleware diff --git a/plank/pkg/server/prometheus.go b/plank/pkg/server/prometheus.go index 4f1f3e5..2680cd7 100644 --- a/plank/pkg/server/prometheus.go +++ b/plank/pkg/server/prometheus.go @@ -1,8 +1,8 @@ // Copyright 2019-2021 VMware, Inc. // SPDX-License-Identifier: BSD-2-Clause -// +build !js -// +build !wasm +//go:build !js && !wasm +// +build !js,!wasm package server diff --git a/plank/pkg/server/spa_config.go b/plank/pkg/server/spa_config.go index ca1b4a0..0d74e75 100644 --- a/plank/pkg/server/spa_config.go +++ b/plank/pkg/server/spa_config.go @@ -16,16 +16,16 @@ import ( // are served from /app/static, BaseUri can be set to /app and StaticAssets to "/app/assets". see config.json // for details. type SpaConfig struct { - RootFolder string `json:"root_folder"` // location where Plank will serve SPA - BaseUri string `json:"base_uri"` // base URI for the SPA - StaticAssets []string `json:"static_assets"` // locations for static assets used by the SPA + RootFolder string `json:"root_folder"` // location where Plank will serve SPA + BaseUri string `json:"base_uri"` // base URI for the SPA + StaticAssets []string `json:"static_assets"` // locations for static assets used by the SPA CacheControlRules map[string]string `json:"cache_control_rules"` // map holding glob pattern - cache-control header value cacheControlRulePairs []middleware.CacheControlRulePair } type regexCacheControlRulePair struct { - regex *regexp.Regexp + regex *regexp.Regexp cacheControlRule string } @@ -34,9 +34,9 @@ type regexCacheControlRulePair struct { func NewSpaConfig(input string) (spaConfig *SpaConfig, err error) { p, uri := utils.DeriveStaticURIFromPath(input) spaConfig = &SpaConfig{ - RootFolder: p, - BaseUri: uri, - CacheControlRules: make(map[string]string), + RootFolder: p, + BaseUri: uri, + CacheControlRules: make(map[string]string), cacheControlRulePairs: make([]middleware.CacheControlRulePair, 0), } @@ -71,4 +71,4 @@ func (s *SpaConfig) CacheControlMiddleware() mux.MiddlewareFunc { handler.ServeHTTP(w, r) }) } -} \ No newline at end of file +} diff --git a/plank/utils/console_helpers.go b/plank/utils/console_helpers.go index 5357ca4..fe3b20b 100644 --- a/plank/utils/console_helpers.go +++ b/plank/utils/console_helpers.go @@ -6,16 +6,16 @@ package utils import "github.com/fatih/color" var ( - InfoHeaderf = color.New(color.FgHiBlue).Add(color.Bold).PrintfFunc() - InfoHeaderFprintf = color.New(color.FgHiBlue).Add(color.Bold).FprintfFunc() - Infof = color.New(color.FgHiCyan).PrintfFunc() + InfoHeaderf = color.New(color.FgHiBlue).Add(color.Bold).PrintfFunc() + InfoHeaderFprintf = color.New(color.FgHiBlue).Add(color.Bold).FprintfFunc() + Infof = color.New(color.FgHiCyan).PrintfFunc() InfoFprintf = color.New(color.FgHiCyan).FprintfFunc() - WarnHeaderf = color.New(color.FgHiYellow).Add(color.Bold).PrintfFunc() + WarnHeaderf = color.New(color.FgHiYellow).Add(color.Bold).PrintfFunc() WarnHeaderFprintf = color.New(color.FgHiYellow).Add(color.Bold).FprintfFunc() - Warnf = color.New(color.FgHiYellow).PrintfFunc() + Warnf = color.New(color.FgHiYellow).PrintfFunc() WarnFprintf = color.New(color.FgHiYellow).FprintfFunc() - ErrorHeaderf = color.New(color.FgHiRed).Add(color.Bold).PrintfFunc() + ErrorHeaderf = color.New(color.FgHiRed).Add(color.Bold).PrintfFunc() ErrorHeaderFprintf = color.New(color.FgHiRed).Add(color.Bold).FprintfFunc() - Errorf = color.New(color.FgHiRed).PrintfFunc() + Errorf = color.New(color.FgHiRed).PrintfFunc() ErrorFprintf = color.New(color.FgHiRed).FprintfFunc() ) diff --git a/service/fabric_service.go b/service/fabric_service.go index cdb52d4..de666e1 100644 --- a/service/fabric_service.go +++ b/service/fabric_service.go @@ -4,18 +4,17 @@ package service import ( - "github.com/vmware/transport-go/model" + "github.com/vmware/transport-go/model" ) // FabricService Interface containing all APIs which should be implemented by Fabric Services. type FabricService interface { - // Handles a single Fabric Request - HandleServiceRequest(request *model.Request, core FabricServiceCore) + // Handles a single Fabric Request + HandleServiceRequest(request *model.Request, core FabricServiceCore) } // FabricInitializableService Optional interface, if implemented by a fabric service, its Init method // will be invoked when the service is registered in the ServiceRegistry. type FabricInitializableService interface { - Init(core FabricServiceCore) error + Init(core FabricServiceCore) error } - diff --git a/service/rest_service_test.go b/service/rest_service_test.go index 47cffd0..7af4309 100644 --- a/service/rest_service_test.go +++ b/service/rest_service_test.go @@ -4,45 +4,45 @@ package service import ( - "bytes" - "encoding/json" - "errors" - "github.com/stretchr/testify/assert" - "github.com/vmware/transport-go/model" - "io/ioutil" - "net/http" - "reflect" - "strings" - "sync" - "testing" + "bytes" + "encoding/json" + "errors" + "github.com/stretchr/testify/assert" + "github.com/vmware/transport-go/model" + "io/ioutil" + "net/http" + "reflect" + "strings" + "sync" + "testing" ) type testItem struct { - Name string `json:"name"` - Count int `json:"count"` + Name string `json:"name"` + Count int `json:"count"` } func TestRestServiceRequest_marshalBody(t *testing.T) { - reqWithStringBody := &RestServiceRequest{Body: "test-body"} - body, err := reqWithStringBody.marshalBody() - assert.Nil(t, err) - assert.Equal(t, []byte("test-body"), body) - - reqWithBytesBody := &RestServiceRequest{Body: []byte{1,2,3,4}} - body, err = reqWithBytesBody.marshalBody() - assert.Nil(t, err) - assert.Equal(t, reqWithBytesBody.Body, body) - - item := testItem{Name: "test-name", Count: 5} - reqWithTestItem := &RestServiceRequest{Body: item} - body, err = reqWithTestItem.marshalBody() - assert.Nil(t, err) - expectedValue, _ := json.Marshal(item) - assert.Equal(t, expectedValue, body) + reqWithStringBody := &RestServiceRequest{Body: "test-body"} + body, err := reqWithStringBody.marshalBody() + assert.Nil(t, err) + assert.Equal(t, []byte("test-body"), body) + + reqWithBytesBody := &RestServiceRequest{Body: []byte{1, 2, 3, 4}} + body, err = reqWithBytesBody.marshalBody() + assert.Nil(t, err) + assert.Equal(t, reqWithBytesBody.Body, body) + + item := testItem{Name: "test-name", Count: 5} + reqWithTestItem := &RestServiceRequest{Body: item} + body, err = reqWithTestItem.marshalBody() + assert.Nil(t, err) + expectedValue, _ := json.Marshal(item) + assert.Equal(t, expectedValue, body) } func TestRestService_AutoRegistration(t *testing.T) { - assert.NotNil(t, GetServiceRegistry().(*serviceRegistry).services[restServiceChannel]) + assert.NotNil(t, GetServiceRegistry().(*serviceRegistry).services[restServiceChannel]) } // RoundTripFunc . @@ -50,320 +50,314 @@ type RoundTripFunc func(req *http.Request) (*http.Response, error) // RoundTrip . func (f RoundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { - return f(req) + return f(req) } func TestRestService_HandleServiceRequest(t *testing.T) { - core := newTestFabricCore(restServiceChannel) - - restService := &restService{} - var lastHttpRequest *http.Request - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - lastHttpRequest = req - return &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewBufferString("test-response-body")), - Header: make(http.Header), - }, nil - }) - - var lastResponse *model.Response - - wg := sync.WaitGroup{} - wg.Add(1) - - mh, _ := core.Bus().ListenStream(restServiceChannel) - mh.Handle( - func(message *model.Message) { - lastResponse = message.Payload.(*model.Response) - wg.Done() - }, - func(e error) { - assert.Fail(t, "unexpected error") - }) - - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - Headers: map[string]string{ "header1": "value1", "header2": "value2"}, - Method: "UPDATE", - Body: "test-body", - ResponseType: reflect.TypeOf(""), - }, - }, core) - - wg.Wait() - - assert.NotNil(t, lastHttpRequest) - assert.Equal(t, lastHttpRequest.URL.String(), "http://localhost:4444/test-url") - assert.Equal(t, lastHttpRequest.Method, "UPDATE") - assert.Equal(t, lastHttpRequest.Header.Get("header1"), "value1") - assert.Equal(t, lastHttpRequest.Header.Get("header2"), "value2") - assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "application/merge-patch+json") - sentBody, _ := ioutil.ReadAll(lastHttpRequest.Body) - assert.Equal(t, sentBody, []byte("test-body")) - - assert.NotNil(t, lastResponse) - assert.Equal(t, lastResponse.Payload, "test-response-body") - assert.False(t, lastResponse.Error) - - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - lastHttpRequest = req - return &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewBufferString(`{"name": "test-name", "count": 2}`)), - Header: make(http.Header), - }, nil - }) - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - Headers: map[string]string {"Content-Type": "json"}, - ResponseType: reflect.TypeOf(testItem{}), - }, - }, core) - - wg.Wait() - - assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "json") - assert.Equal(t, lastResponse.Payload, testItem{Name:"test-name", Count: 2}) - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - ResponseType: reflect.TypeOf(&testItem{}), - }, - }, core) - - wg.Wait() - - assert.Equal(t, lastResponse.Payload, &testItem{Name:"test-name", Count: 2}) - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - }, - }, core) - - wg.Wait() - - assert.Equal(t, lastResponse.Payload, map[string]interface{} {"name": "test-name", "count": float64(2)}) - - - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - lastHttpRequest = req - return &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewBuffer([]byte{1,2,3,4,5})), - Header: make(http.Header), - }, nil - }) - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - ResponseType: reflect.TypeOf([]byte{}), - }, - }, core) - - wg.Wait() - - assert.Equal(t, lastResponse.Payload, []byte{1,2,3,4,5}) + core := newTestFabricCore(restServiceChannel) + + restService := &restService{} + var lastHttpRequest *http.Request + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + lastHttpRequest = req + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString("test-response-body")), + Header: make(http.Header), + }, nil + }) + + var lastResponse *model.Response + + wg := sync.WaitGroup{} + wg.Add(1) + + mh, _ := core.Bus().ListenStream(restServiceChannel) + mh.Handle( + func(message *model.Message) { + lastResponse = message.Payload.(*model.Response) + wg.Done() + }, + func(e error) { + assert.Fail(t, "unexpected error") + }) + + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + Headers: map[string]string{"header1": "value1", "header2": "value2"}, + Method: "UPDATE", + Body: "test-body", + ResponseType: reflect.TypeOf(""), + }, + }, core) + + wg.Wait() + + assert.NotNil(t, lastHttpRequest) + assert.Equal(t, lastHttpRequest.URL.String(), "http://localhost:4444/test-url") + assert.Equal(t, lastHttpRequest.Method, "UPDATE") + assert.Equal(t, lastHttpRequest.Header.Get("header1"), "value1") + assert.Equal(t, lastHttpRequest.Header.Get("header2"), "value2") + assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "application/merge-patch+json") + sentBody, _ := ioutil.ReadAll(lastHttpRequest.Body) + assert.Equal(t, sentBody, []byte("test-body")) + + assert.NotNil(t, lastResponse) + assert.Equal(t, lastResponse.Payload, "test-response-body") + assert.False(t, lastResponse.Error) + + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + lastHttpRequest = req + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString(`{"name": "test-name", "count": 2}`)), + Header: make(http.Header), + }, nil + }) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + Headers: map[string]string{"Content-Type": "json"}, + ResponseType: reflect.TypeOf(testItem{}), + }, + }, core) + + wg.Wait() + + assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "json") + assert.Equal(t, lastResponse.Payload, testItem{Name: "test-name", Count: 2}) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + ResponseType: reflect.TypeOf(&testItem{}), + }, + }, core) + + wg.Wait() + + assert.Equal(t, lastResponse.Payload, &testItem{Name: "test-name", Count: 2}) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + }, + }, core) + + wg.Wait() + + assert.Equal(t, lastResponse.Payload, map[string]interface{}{"name": "test-name", "count": float64(2)}) + + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + lastHttpRequest = req + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBuffer([]byte{1, 2, 3, 4, 5})), + Header: make(http.Header), + }, nil + }) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + ResponseType: reflect.TypeOf([]byte{}), + }, + }, core) + + wg.Wait() + + assert.Equal(t, lastResponse.Payload, []byte{1, 2, 3, 4, 5}) } func TestRestService_HandleJavaServiceRequest(t *testing.T) { - core := newTestFabricCore(restServiceChannel) - - wg := sync.WaitGroup{} - - restService := &restService{} - var lastHttpRequest *http.Request - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - lastHttpRequest = req - defer wg.Done() - return &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewBufferString("test-response-body")), - Header: make(http.Header), - }, nil - }) - - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: map[string]interface{} { - "uri": "http://localhost:4444/test-url", - "headers": map[string]string { "header1": "value1", "header2": "value2"}, - "method": "UPDATE", - "Body": "test-body", - "apiClass": "java.lang.String", - }, - }, core) - - wg.Wait() - - assert.NotNil(t, lastHttpRequest) - assert.Equal(t, lastHttpRequest.URL.String(), "http://localhost:4444/test-url") - assert.Equal(t, lastHttpRequest.Method, "UPDATE") - assert.Equal(t, lastHttpRequest.Header.Get("header1"), "value1") - assert.Equal(t, lastHttpRequest.Header.Get("header2"), "value2") - assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "application/merge-patch+json") - sentBody, _ := ioutil.ReadAll(lastHttpRequest.Body) - assert.Equal(t, sentBody, []byte("test-body")) + core := newTestFabricCore(restServiceChannel) + + wg := sync.WaitGroup{} + + restService := &restService{} + var lastHttpRequest *http.Request + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + lastHttpRequest = req + defer wg.Done() + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString("test-response-body")), + Header: make(http.Header), + }, nil + }) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: map[string]interface{}{ + "uri": "http://localhost:4444/test-url", + "headers": map[string]string{"header1": "value1", "header2": "value2"}, + "method": "UPDATE", + "Body": "test-body", + "apiClass": "java.lang.String", + }, + }, core) + + wg.Wait() + + assert.NotNil(t, lastHttpRequest) + assert.Equal(t, lastHttpRequest.URL.String(), "http://localhost:4444/test-url") + assert.Equal(t, lastHttpRequest.Method, "UPDATE") + assert.Equal(t, lastHttpRequest.Header.Get("header1"), "value1") + assert.Equal(t, lastHttpRequest.Header.Get("header2"), "value2") + assert.Equal(t, lastHttpRequest.Header.Get("Content-Type"), "application/merge-patch+json") + sentBody, _ := ioutil.ReadAll(lastHttpRequest.Body) + assert.Equal(t, sentBody, []byte("test-body")) } func TestRestService_HandleServiceRequest_InvalidInput(t *testing.T) { - core := newTestFabricCore(restServiceChannel) - - restService := &restService{} - var lastResponse *model.Response - - wg := sync.WaitGroup{} - wg.Add(1) - mh, _ := core.Bus().ListenStream(restServiceChannel) - mh.Handle( - func(message *model.Message) { - lastResponse = message.Payload.(*model.Response) - wg.Done() - }, - func(e error) { - assert.Fail(t, "unexpected error") - }) - - restService.HandleServiceRequest(&model.Request{ - Payload: RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - Method: "UPDATE", - }, - }, core) - - wg.Wait() - - assert.NotNil(t, lastResponse) - assert.True(t, lastResponse.Error) - assert.Equal(t, lastResponse.ErrorCode, 500) - assert.Equal(t, lastResponse.ErrorMessage, "invalid RestServiceRequest payload") - - wg.Add(1) - - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - Method: "@!#$%^&**()", - }, - }, core) - - wg.Wait() - assert.True(t, lastResponse.Error) - assert.Equal(t, lastResponse.ErrorCode, 500) - - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - return nil, errors.New("custom-rest-error") - }) - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - }, - }, core) - wg.Wait() - - assert.True(t, lastResponse.Error) - assert.Equal(t, lastResponse.ErrorCode, 500) - assert.True(t, strings.Contains(lastResponse.ErrorMessage, "custom-rest-error")) - - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 404, - Status: "404 Not Found", - Body: ioutil.NopCloser(bytes.NewBufferString("error-response")), - Header: make(http.Header), - }, nil - }) - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - }, - }, core) - wg.Wait() - - assert.True(t, lastResponse.Error) - assert.Equal(t, lastResponse.ErrorCode, 404) - assert.Equal(t, lastResponse.ErrorMessage, "rest-service error, unable to complete request: 404 Not Found") - - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - return &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewBufferString("}")), - Header: make(http.Header), - }, nil - }) - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - ResponseType: reflect.TypeOf(&testItem{}), - }, - }, core) - wg.Wait() - - assert.True(t, lastResponse.Error) - assert.Equal(t, lastResponse.ErrorCode, 500) - assert.True(t, strings.HasPrefix(lastResponse.ErrorMessage, "failed to deserialize response:")) + core := newTestFabricCore(restServiceChannel) + + restService := &restService{} + var lastResponse *model.Response + + wg := sync.WaitGroup{} + wg.Add(1) + mh, _ := core.Bus().ListenStream(restServiceChannel) + mh.Handle( + func(message *model.Message) { + lastResponse = message.Payload.(*model.Response) + wg.Done() + }, + func(e error) { + assert.Fail(t, "unexpected error") + }) + + restService.HandleServiceRequest(&model.Request{ + Payload: RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + Method: "UPDATE", + }, + }, core) + + wg.Wait() + + assert.NotNil(t, lastResponse) + assert.True(t, lastResponse.Error) + assert.Equal(t, lastResponse.ErrorCode, 500) + assert.Equal(t, lastResponse.ErrorMessage, "invalid RestServiceRequest payload") + + wg.Add(1) + + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + Method: "@!#$%^&**()", + }, + }, core) + + wg.Wait() + assert.True(t, lastResponse.Error) + assert.Equal(t, lastResponse.ErrorCode, 500) + + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + return nil, errors.New("custom-rest-error") + }) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + }, + }, core) + wg.Wait() + + assert.True(t, lastResponse.Error) + assert.Equal(t, lastResponse.ErrorCode, 500) + assert.True(t, strings.Contains(lastResponse.ErrorMessage, "custom-rest-error")) + + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 404, + Status: "404 Not Found", + Body: ioutil.NopCloser(bytes.NewBufferString("error-response")), + Header: make(http.Header), + }, nil + }) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + }, + }, core) + wg.Wait() + + assert.True(t, lastResponse.Error) + assert.Equal(t, lastResponse.ErrorCode, 404) + assert.Equal(t, lastResponse.ErrorMessage, "rest-service error, unable to complete request: 404 Not Found") + + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString("}")), + Header: make(http.Header), + }, nil + }) + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + ResponseType: reflect.TypeOf(&testItem{}), + }, + }, core) + wg.Wait() + + assert.True(t, lastResponse.Error) + assert.Equal(t, lastResponse.ErrorCode, 500) + assert.True(t, strings.HasPrefix(lastResponse.ErrorMessage, "failed to deserialize response:")) } func TestRestService_setBaseHost(t *testing.T) { - core := newTestFabricCore(restServiceChannel) - restService := &restService{} - - wg := sync.WaitGroup{} - - - var lastHttpRequest *http.Request - restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { - lastHttpRequest = req - wg.Done() - return &http.Response{ - StatusCode: 200, - Body: ioutil.NopCloser(bytes.NewBufferString("test-response-body")), - Header: make(http.Header), - }, nil - }) - - restService.setBaseHost("appfabric.vmware.com:9999") - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - }, - }, core) - - wg.Wait() - - assert.Equal(t, lastHttpRequest.Host, "appfabric.vmware.com:9999") - - - restService.setBaseHost("") - - wg.Add(1) - restService.HandleServiceRequest(&model.Request{ - Payload: &RestServiceRequest{ - Uri: "http://localhost:4444/test-url", - }, - }, core) - wg.Wait() - - assert.Equal(t, lastHttpRequest.Host, "localhost:4444") + core := newTestFabricCore(restServiceChannel) + restService := &restService{} + + wg := sync.WaitGroup{} + + var lastHttpRequest *http.Request + restService.httpClient.Transport = RoundTripFunc(func(req *http.Request) (*http.Response, error) { + lastHttpRequest = req + wg.Done() + return &http.Response{ + StatusCode: 200, + Body: ioutil.NopCloser(bytes.NewBufferString("test-response-body")), + Header: make(http.Header), + }, nil + }) + + restService.setBaseHost("appfabric.vmware.com:9999") + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + }, + }, core) + + wg.Wait() + + assert.Equal(t, lastHttpRequest.Host, "appfabric.vmware.com:9999") + + restService.setBaseHost("") + + wg.Add(1) + restService.HandleServiceRequest(&model.Request{ + Payload: &RestServiceRequest{ + Uri: "http://localhost:4444/test-url", + }, + }, core) + wg.Wait() + + assert.Equal(t, lastHttpRequest.Host, "localhost:4444") } - - diff --git a/service/service_lifecycle_manager.go b/service/service_lifecycle_manager.go index 5a91bb9..792f8fd 100644 --- a/service/service_lifecycle_manager.go +++ b/service/service_lifecycle_manager.go @@ -15,23 +15,23 @@ type ServiceLifecycleManager interface { } type ServiceLifecycleHookEnabled interface { - OnServiceReady() chan bool // service initialization logic should be implemented here + OnServiceReady() chan bool // service initialization logic should be implemented here OnServerShutdown() // teardown logic goes here and will be automatically invoked on graceful server shutdown GetRESTBridgeConfig() []*RESTBridgeConfig // service-to-REST endpoint mappings go here } type SetupRESTBridgeRequest struct { ServiceChannel string - Override bool - Config []*RESTBridgeConfig + Override bool + Config []*RESTBridgeConfig } type RESTBridgeConfig struct { - ServiceChannel string // transport service channel - Uri string // URI to map the transport service to - Method string // HTTP verb to map the transport service request to URI with - AllowHead bool // whether HEAD calls are allowed for this bridge point - AllowOptions bool // whether OPTIONS calls are allowed for this bridge point + ServiceChannel string // transport service channel + Uri string // URI to map the transport service to + Method string // HTTP verb to map the transport service request to URI with + AllowHead bool // whether HEAD calls are allowed for this bridge point + AllowOptions bool // whether OPTIONS calls are allowed for this bridge point FabricRequestBuilder RequestBuilder // function to transform HTTP request into a transport request } @@ -82,4 +82,4 @@ func GetServiceLifecycleManager() ServiceLifecycleManager { // newServiceLifecycleManager returns a new instance of ServiceLifecycleManager func newServiceLifecycleManager(reg ServiceRegistry) ServiceLifecycleManager { return &serviceLifecycleManager{serviceRegistryRef: reg} -} \ No newline at end of file +} diff --git a/service/service_lifecycle_manager_test.go b/service/service_lifecycle_manager_test.go index 285e151..c200dcb 100644 --- a/service/service_lifecycle_manager_test.go +++ b/service/service_lifecycle_manager_test.go @@ -29,7 +29,6 @@ func TestServiceLifecycleManager_GetServiceHooks(t *testing.T) { // act hooks := lcm.GetServiceHooks("another-test-channel") - // assert assert.NotNil(t, hooks) } @@ -114,4 +113,4 @@ func TestServiceLifecycleManager_OverrideRESTBridgeConfig(t *testing.T) { err = lcm.OverrideRESTBridgeConfig("another-test-channel", []*RESTBridgeConfig{payload}) assert.Nil(t, err) wg.Wait() -} \ No newline at end of file +} diff --git a/service/service_registry.go b/service/service_registry.go index 9753d3d..018fbb8 100644 --- a/service/service_registry.go +++ b/service/service_registry.go @@ -20,7 +20,7 @@ const ( LifecycleManagerChannelName = bus.TRANSPORT_INTERNAL_CHANNEL_PREFIX + "service-lifecycle-manager" // store constants - ServiceReadyStore = "service-ready-notification-store" + ServiceReadyStore = "service-ready-notification-store" ServiceInitStateChange = "service-init-state-change" ) @@ -46,9 +46,9 @@ type ServiceRegistry interface { } type serviceRegistry struct { - lock sync.Mutex - services map[string]*fabricServiceWrapper - bus bus.EventBus + lock sync.Mutex + services map[string]*fabricServiceWrapper + bus bus.EventBus lifecycleManager *serviceLifecycleManager } diff --git a/service/service_registry_test.go b/service/service_registry_test.go index 16c1bcd..e242d09 100644 --- a/service/service_registry_test.go +++ b/service/service_registry_test.go @@ -4,214 +4,214 @@ package service import ( - "errors" - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/vmware/transport-go/bus" - "github.com/vmware/transport-go/model" - "net/http" - "sync" - "testing" + "errors" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/vmware/transport-go/bus" + "github.com/vmware/transport-go/model" + "net/http" + "sync" + "testing" ) func newTestServiceRegistry() *serviceRegistry { - eventBus := bus.NewEventBusInstance() - return newServiceRegistry(eventBus).(*serviceRegistry) + eventBus := bus.NewEventBusInstance() + return newServiceRegistry(eventBus).(*serviceRegistry) } func newTestServiceLifecycleManager(sr ServiceRegistry) ServiceLifecycleManager { - return newServiceLifecycleManager(sr) + return newServiceLifecycleManager(sr) } type mockFabricService struct { - processedRequests []*model.Request - core FabricServiceCore - wg sync.WaitGroup + processedRequests []*model.Request + core FabricServiceCore + wg sync.WaitGroup } func (fs *mockFabricService) HandleServiceRequest(request *model.Request, core FabricServiceCore) { - fs.processedRequests = append(fs.processedRequests, request) - fs.core = core - fs.wg.Done() + fs.processedRequests = append(fs.processedRequests, request) + fs.core = core + fs.wg.Done() } type mockLifecycleHookEnabledService struct { - initChan chan bool - core FabricServiceCore - shutdown bool + initChan chan bool + core FabricServiceCore + shutdown bool } func (s *mockLifecycleHookEnabledService) HandleServiceRequest(request *model.Request, core FabricServiceCore) { } func (s *mockLifecycleHookEnabledService) OnServiceReady() chan bool { - s.initChan = make(chan bool, 1) - s.initChan <- true - return s.initChan + s.initChan = make(chan bool, 1) + s.initChan <- true + return s.initChan } func (s *mockLifecycleHookEnabledService) OnServerShutdown() { - s.shutdown = true + s.shutdown = true } func (s *mockLifecycleHookEnabledService) GetRESTBridgeConfig() []*RESTBridgeConfig { - return []*RESTBridgeConfig{ - { - ServiceChannel: "another-test-channel", - Uri: "/rest/test", - Method: http.MethodGet, - AllowHead: true, - AllowOptions: true, - FabricRequestBuilder: func(w http.ResponseWriter, r *http.Request) model.Request { - return model.Request{ - Id: &uuid.UUID{}, - Payload: "test", - } - }, - }, - } + return []*RESTBridgeConfig{ + { + ServiceChannel: "another-test-channel", + Uri: "/rest/test", + Method: http.MethodGet, + AllowHead: true, + AllowOptions: true, + FabricRequestBuilder: func(w http.ResponseWriter, r *http.Request) model.Request { + return model.Request{ + Id: &uuid.UUID{}, + Payload: "test", + } + }, + }, + } } type mockInitializableService struct { - initialized bool - core FabricServiceCore - initError error + initialized bool + core FabricServiceCore + initError error } func (fs *mockInitializableService) Init(core FabricServiceCore) error { - fs.core = core - fs.initialized = true - return fs.initError + fs.core = core + fs.initialized = true + return fs.initError } func (fs *mockInitializableService) HandleServiceRequest(request *model.Request, core FabricServiceCore) { } func TestGetServiceRegistry(t *testing.T) { - sr := GetServiceRegistry() - sr2 := GetServiceRegistry() - assert.NotNil(t, sr) - assert.Equal(t, sr, sr2) + sr := GetServiceRegistry() + sr2 := GetServiceRegistry() + assert.NotNil(t, sr) + assert.Equal(t, sr, sr2) } func TestServiceRegistry_RegisterService(t *testing.T) { - registry := newTestServiceRegistry() - mockService := &mockFabricService{} + registry := newTestServiceRegistry() + mockService := &mockFabricService{} - assert.Nil(t, registry.RegisterService(mockService, "test-channel")) - assert.True(t, registry.bus.GetChannelManager().CheckChannelExists("test-channel")) + assert.Nil(t, registry.RegisterService(mockService, "test-channel")) + assert.True(t, registry.bus.GetChannelManager().CheckChannelExists("test-channel")) - id := uuid.New() - req := model.Request{ - Id: &id, - Request: "test-request", - Payload: "request-payload", - } + id := uuid.New() + req := model.Request{ + Id: &id, + Request: "test-request", + Payload: "request-payload", + } - mockService.wg.Add(1) - registry.bus.SendRequestMessage("test-channel", req, nil) - mockService.wg.Wait() + mockService.wg.Add(1) + registry.bus.SendRequestMessage("test-channel", req, nil) + mockService.wg.Wait() - assert.Equal(t, len(mockService.processedRequests), 1) - assert.Equal(t, *mockService.processedRequests[0], req) - assert.NotNil(t, mockService.core) + assert.Equal(t, len(mockService.processedRequests), 1) + assert.Equal(t, *mockService.processedRequests[0], req) + assert.NotNil(t, mockService.core) - registry.bus.SendRequestMessage("test-channel", "invalid-request", nil) - registry.bus.SendRequestMessage("test-channel", nil, nil) - registry.bus.SendResponseMessage("test-channel", req, nil) - registry.bus.SendErrorMessage("test-channel", errors.New("test-error"), nil) + registry.bus.SendRequestMessage("test-channel", "invalid-request", nil) + registry.bus.SendRequestMessage("test-channel", nil, nil) + registry.bus.SendResponseMessage("test-channel", req, nil) + registry.bus.SendErrorMessage("test-channel", errors.New("test-error"), nil) - mockService.wg.Add(1) - registry.bus.SendRequestMessage("test-channel", &req, nil) - mockService.wg.Wait() + mockService.wg.Add(1) + registry.bus.SendRequestMessage("test-channel", &req, nil) + mockService.wg.Wait() - assert.Equal(t, len(mockService.processedRequests), 2) - assert.Equal(t, mockService.processedRequests[1], &req) - assert.NotNil(t, mockService.core) + assert.Equal(t, len(mockService.processedRequests), 2) + assert.Equal(t, mockService.processedRequests[1], &req) + assert.NotNil(t, mockService.core) - mockService.wg.Add(1) - uuid := uuid.New() - registry.bus.SendRequestMessage("test-channel", model.Request{ - Request: "test-request-2", - Payload: "request-payload", - }, &uuid) - mockService.wg.Wait() + mockService.wg.Add(1) + uuid := uuid.New() + registry.bus.SendRequestMessage("test-channel", model.Request{ + Request: "test-request-2", + Payload: "request-payload", + }, &uuid) + mockService.wg.Wait() - assert.Equal(t, len(mockService.processedRequests), 3) - assert.Equal(t, mockService.processedRequests[2].Id, &uuid) + assert.Equal(t, len(mockService.processedRequests), 3) + assert.Equal(t, mockService.processedRequests[2].Id, &uuid) - assert.EqualError(t, registry.RegisterService(&mockFabricService{}, "test-channel"), - "unable to register service: service channel name is already used: test-channel") + assert.EqualError(t, registry.RegisterService(&mockFabricService{}, "test-channel"), + "unable to register service: service channel name is already used: test-channel") - assert.EqualError(t, registry.RegisterService(nil, "test-channel2"), - "unable to register service: nil service") + assert.EqualError(t, registry.RegisterService(nil, "test-channel2"), + "unable to register service: nil service") - assert.False(t, registry.bus.GetChannelManager().CheckChannelExists("test-channel2")) + assert.False(t, registry.bus.GetChannelManager().CheckChannelExists("test-channel2")) } func TestServiceRegistry_RegisterInitializableService(t *testing.T) { - registry := newTestServiceRegistry() - mockService := &mockInitializableService{} - assert.Nil(t, registry.RegisterService(mockService, "test-channel")) + registry := newTestServiceRegistry() + mockService := &mockInitializableService{} + assert.Nil(t, registry.RegisterService(mockService, "test-channel")) - assert.True(t, mockService.initialized) - assert.NotNil(t, mockService.core) + assert.True(t, mockService.initialized) + assert.NotNil(t, mockService.core) - assert.EqualError(t, - registry.RegisterService(&mockInitializableService{initError: errors.New("init-error")}, "test-channel2"), - "init-error") + assert.EqualError(t, + registry.RegisterService(&mockInitializableService{initError: errors.New("init-error")}, "test-channel2"), + "init-error") } func TestServiceRegistry_UnregisterService(t *testing.T) { - registry := newTestServiceRegistry() - mockService := &mockFabricService{} + registry := newTestServiceRegistry() + mockService := &mockFabricService{} - assert.Nil(t, registry.RegisterService(mockService, "test-channel")) - assert.True(t, registry.bus.GetChannelManager().CheckChannelExists("test-channel")) + assert.Nil(t, registry.RegisterService(mockService, "test-channel")) + assert.True(t, registry.bus.GetChannelManager().CheckChannelExists("test-channel")) - id := uuid.New() - req := model.Request{ - Id: &id, - Request: "test-request", - Payload: "request-payload", - } + id := uuid.New() + req := model.Request{ + Id: &id, + Request: "test-request", + Payload: "request-payload", + } - assert.Nil(t, registry.UnregisterService("test-channel")) - registry.bus.SendRequestMessage("test-channel", req, nil) + assert.Nil(t, registry.UnregisterService("test-channel")) + registry.bus.SendRequestMessage("test-channel", req, nil) - assert.Equal(t, len(mockService.processedRequests), 0) - assert.EqualError(t, registry.UnregisterService("test-channel"), - "unable to unregister service: no service is registered for channel \"test-channel\"") + assert.Equal(t, len(mockService.processedRequests), 0) + assert.EqualError(t, registry.UnregisterService("test-channel"), + "unable to unregister service: no service is registered for channel \"test-channel\"") } func TestServiceRegistry_SetGlobalRestServiceBaseHost(t *testing.T) { - registry := newTestServiceRegistry() - registry.SetGlobalRestServiceBaseHost("localhost:9999") - assert.Equal(t, "localhost:9999", - registry.services[restServiceChannel].service.(*restService).baseHost) + registry := newTestServiceRegistry() + registry.SetGlobalRestServiceBaseHost("localhost:9999") + assert.Equal(t, "localhost:9999", + registry.services[restServiceChannel].service.(*restService).baseHost) } func TestServiceRegistry_GetAllServiceChannels(t *testing.T) { - registry := newTestServiceRegistry() - mockService := &mockFabricService{} + registry := newTestServiceRegistry() + mockService := &mockFabricService{} - registry.RegisterService(mockService, "test-channel") - chans := registry.GetAllServiceChannels() + registry.RegisterService(mockService, "test-channel") + chans := registry.GetAllServiceChannels() - assert.Len(t, chans, 1) - assert.EqualValues(t, "test-channel", chans[0]) + assert.Len(t, chans, 1) + assert.EqualValues(t, "test-channel", chans[0]) } func TestServiceRegistry_RegisterService_LifecycleHookEnabled(t *testing.T) { - svc := &mockLifecycleHookEnabledService{} - registry := newTestServiceRegistry() - registry.RegisterService(svc, "another-test-channel") + svc := &mockLifecycleHookEnabledService{} + registry := newTestServiceRegistry() + registry.RegisterService(svc, "another-test-channel") - assert.True(t, <-svc.OnServiceReady()) + assert.True(t, <-svc.OnServiceReady()) - svc.OnServerShutdown() - assert.True(t, svc.shutdown) + svc.OnServerShutdown() + assert.True(t, svc.shutdown) - restBridgeConfig := svc.GetRESTBridgeConfig() - assert.NotNil(t, restBridgeConfig) -} \ No newline at end of file + restBridgeConfig := svc.GetRESTBridgeConfig() + assert.NotNil(t, restBridgeConfig) +} diff --git a/stompserver/config.go b/stompserver/config.go index 884ce1e..c871c9f 100644 --- a/stompserver/config.go +++ b/stompserver/config.go @@ -6,50 +6,45 @@ package stompserver import "strings" type StompConfig interface { - HeartBeat() int64 - AppDestinationPrefix() []string - IsAppRequestDestination(destination string) bool + HeartBeat() int64 + AppDestinationPrefix() []string + IsAppRequestDestination(destination string) bool } type stompConfig struct { - heartbeat int64 - appDestPrefix []string + heartbeat int64 + appDestPrefix []string } func NewStompConfig(heartBeatMs int64, appDestinationPrefix []string) StompConfig { - prefixes := make([]string, len(appDestinationPrefix)) - for i := 0; i < len(appDestinationPrefix); i++ { - if appDestinationPrefix[i] != "" && !strings.HasSuffix(appDestinationPrefix[i], "/") { - prefixes[i] = appDestinationPrefix[i] + "/" - } else { - prefixes[i] = appDestinationPrefix[i] - } - } - - return &stompConfig{ - heartbeat: heartBeatMs, - appDestPrefix: prefixes, - } + prefixes := make([]string, len(appDestinationPrefix)) + for i := 0; i < len(appDestinationPrefix); i++ { + if appDestinationPrefix[i] != "" && !strings.HasSuffix(appDestinationPrefix[i], "/") { + prefixes[i] = appDestinationPrefix[i] + "/" + } else { + prefixes[i] = appDestinationPrefix[i] + } + } + + return &stompConfig{ + heartbeat: heartBeatMs, + appDestPrefix: prefixes, + } } func (c *stompConfig) HeartBeat() int64 { - return c.heartbeat + return c.heartbeat } func (c *stompConfig) AppDestinationPrefix() []string { - return c.appDestPrefix + return c.appDestPrefix } func (c *stompConfig) IsAppRequestDestination(destination string) bool { - for _, prefix := range c.appDestPrefix { - if prefix != "" && strings.HasPrefix(destination, prefix) { - return true - } - } - return false + for _, prefix := range c.appDestPrefix { + if prefix != "" && strings.HasPrefix(destination, prefix) { + return true + } + } + return false } - - - - - diff --git a/stompserver/errors.go b/stompserver/errors.go index ab8b875..1fd992a 100644 --- a/stompserver/errors.go +++ b/stompserver/errors.go @@ -4,18 +4,18 @@ package stompserver const ( - notConnectedStompError = stompErrorMessage("not connected") - unexpectedStompCommandError = stompErrorMessage("unexpected frame command") - unsupportedStompCommandError = stompErrorMessage("unsupported command") - unsupportedStompVersionError = stompErrorMessage("unsupported STOMP version") - invalidSubscriptionError = stompErrorMessage("invalid subscription") - invalidFrameError = stompErrorMessage("invalid frame") - invalidHeaderError = stompErrorMessage("invalid frame header") - invalidSendDestinationError = stompErrorMessage("invalid send destination") + notConnectedStompError = stompErrorMessage("not connected") + unexpectedStompCommandError = stompErrorMessage("unexpected frame command") + unsupportedStompCommandError = stompErrorMessage("unsupported command") + unsupportedStompVersionError = stompErrorMessage("unsupported STOMP version") + invalidSubscriptionError = stompErrorMessage("invalid subscription") + invalidFrameError = stompErrorMessage("invalid frame") + invalidHeaderError = stompErrorMessage("invalid frame header") + invalidSendDestinationError = stompErrorMessage("invalid send destination") ) type stompErrorMessage string func (e stompErrorMessage) Error() string { - return string(e) + return string(e) } From e47c0c53460ce9f258f7d6061b691709f38abdbb Mon Sep 17 00:00:00 2001 From: Josh Kim Date: Sat, 4 Dec 2021 12:19:32 -0800 Subject: [PATCH 2/2] fix: UUID comparison logic bug UUID is used throughout the entire codebase of Transport and is especially important in filtering messages based on their destination. While implementing basic WASM bridge for the bus I realized that UUID comparison was failing, i.e. returning true for the equality check between two dififerent UUID instances. Turns out the use of .ID() method on the UUID object only returns the first 4 bytes of the underlying 16-byte slice. This means as long as the first 8 hexadecimal characters matched between two UUID instances, they would come out as equal, and it would therefore mean a significantly higher chance of UUID collision.. The fix introduced in this PR is to use the string comparison of the full UUID in places where .ID() was used. Signed-off-by: Josh Kim --- bridge/broker_connector_test.go | 2 +- bus/channel.go | 6 +++--- bus/eventbus.go | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bridge/broker_connector_test.go b/bridge/broker_connector_test.go index 0d84837..82f4e98 100644 --- a/bridge/broker_connector_test.go +++ b/bridge/broker_connector_test.go @@ -279,7 +279,7 @@ func TestBrokerConnector_Subscribe(t *testing.T) { // check re-subscribe returns same sub s2, _ := c.Subscribe("/topic/test") - assert.Equal(t, s.GetId().ID(), s2.GetId().ID()) + assert.Equal(t, s.GetId().String(), s2.GetId().String()) c.Disconnect() }) diff --git a/bus/channel.go b/bus/channel.go index 5c9f798..33ff4d5 100644 --- a/bus/channel.go +++ b/bus/channel.go @@ -112,7 +112,7 @@ func (channel *Channel) unsubscribeHandler(uuid *uuid.UUID) bool { defer channel.channelLock.Unlock() for i, handler := range channel.eventHandlers { - if handler.uuid.ID() == uuid.ID() { + if handler.uuid.String() == uuid.String() { channel.removeEventHandler(i) return true } @@ -152,7 +152,7 @@ func (channel *Channel) isBrokerSubscribed(sub bridge.Subscription) bool { defer channel.channelLock.Unlock() for _, cs := range channel.brokerSubs { - if sub.GetId().ID() == cs.s.GetId().ID() { + if sub.GetId().String() == cs.s.GetId().String() { return true } } @@ -206,7 +206,7 @@ func (channel *Channel) removeBrokerSubscription(sub bridge.Subscription) { defer channel.channelLock.Unlock() for i, cs := range channel.brokerSubs { - if sub.GetId().ID() == cs.s.GetId().ID() { + if sub.GetId().String() == cs.s.GetId().String() { channel.brokerSubs = removeSub(channel.brokerSubs, i) } } diff --git a/bus/eventbus.go b/bus/eventbus.go index d775f10..744641b 100644 --- a/bus/eventbus.go +++ b/bus/eventbus.go @@ -556,7 +556,7 @@ func (bus *transportEventBus) wrapMessageHandler( if msg.Direction == dir { // if we're checking for specific traffic, check a DestinationId match is required. if !messageHandler.ignoreId && - (msg.DestinationId != nil && id != nil) && (id.ID() == msg.DestinationId.ID()) { + (msg.DestinationId != nil && id != nil) && (id.String() == msg.DestinationId.String()) { successHandler(msg) } if messageHandler.ignoreId {