diff --git a/be1-go/hub/standard_hub/hub_state/ThreadSafeSlice.go b/be1-go/hub/standard_hub/hub_state/ThreadSafeSlice.go new file mode 100644 index 0000000000..8da635cf0f --- /dev/null +++ b/be1-go/hub/standard_hub/hub_state/ThreadSafeSlice.go @@ -0,0 +1,29 @@ +package hub_state + +import ( + "golang.org/x/exp/slices" + "sync" +) + +type ThreadSafeSlice[E comparable] struct { + sync.RWMutex + els []E +} + +func NewThreadSafeSlice[E comparable]() ThreadSafeSlice[E] { + return ThreadSafeSlice[E]{ + els: make([]E, 0), + } +} + +func (i *ThreadSafeSlice[E]) Append(elems ...E) { + i.Lock() + defer i.Unlock() + i.els = append(i.els, elems...) +} + +func (i *ThreadSafeSlice[E]) Contains(elem E) bool { + i.RLock() + defer i.RUnlock() + return slices.Contains(i.els, elem) +} diff --git a/be1-go/hub/standard_hub/message_handling.go b/be1-go/hub/standard_hub/message_handling.go index cad06ebe79..36ae998012 100644 --- a/be1-go/hub/standard_hub/message_handling.go +++ b/be1-go/hub/standard_hub/message_handling.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "encoding/json" "popstellar/crypto" + "popstellar/hub/standard_hub/hub_state" jsonrpc "popstellar/message" "popstellar/message/answer" "popstellar/message/messagedata" @@ -210,7 +211,7 @@ func (h *Hub) handleGetMessagesByIdAnswer(senderSocket socket.Socket, answerMsg } } // Add contents from tempBlacklist to h.blacklist - h.blacklist = append(h.blacklist, tempBlacklist...) + h.blacklist.Append(tempBlacklist...) return xerrors.Errorf("failed to process messages: %v", err) } @@ -412,7 +413,7 @@ func (h *Hub) handleHeartbeat(socket socket.Socket, receivedIds := heartbeat.Params - missingIds := getMissingIds(receivedIds, h.hubInbox.GetIDsTable(), h.blacklist) + missingIds := getMissingIds(receivedIds, h.hubInbox.GetIDsTable(), &h.blacklist) if len(missingIds) > 0 { err = h.sendGetMessagesByIdToServer(socket, missingIds) @@ -468,11 +469,11 @@ func (h *Hub) handleGreetServer(socket socket.Socket, byteMessage []byte) error // getMissingIds compares two maps of channel Ids associated to slices of message Ids to // determine the missing Ids from the storedIds map with respect to the receivedIds map -func getMissingIds(receivedIds map[string][]string, storedIds map[string][]string, blacklist []string) map[string][]string { +func getMissingIds(receivedIds map[string][]string, storedIds map[string][]string, blacklist *hub_state.ThreadSafeSlice[string]) map[string][]string { missingIds := make(map[string][]string) for channelId, receivedMessageIds := range receivedIds { for _, messageId := range receivedMessageIds { - blacklisted := slices.Contains(blacklist, messageId) + blacklisted := blacklist.Contains(messageId) storedIdsForChannel, channelKnown := storedIds[channelId] if blacklisted { break @@ -580,7 +581,7 @@ func (h *Hub) loopOverMessages(messages *map[string][]json.RawMessage, senderSoc continue } - if slices.Contains(h.blacklist, messageData.MessageID) { + if h.blacklist.Contains(messageData.MessageID) { break } diff --git a/be1-go/hub/standard_hub/mod.go b/be1-go/hub/standard_hub/mod.go index 1567cb7724..e4872c3bd6 100644 --- a/be1-go/hub/standard_hub/mod.go +++ b/be1-go/hub/standard_hub/mod.go @@ -92,7 +92,7 @@ type Hub struct { // the server will not ask for them again in the heartbeat // and will not process them if they are received again // @TODO remove the messages from the blacklist after a certain amount of time by trying to process them again - blacklist []string + blacklist state.ThreadSafeSlice[string] } // NewHub returns a new Hub. @@ -126,7 +126,7 @@ func NewHub(pubKeyOwner kyber.Point, clientServerAddress string, serverServerAdd hubInbox: *inbox.NewHubInbox(rootChannel), queries: state.NewQueries(), peers: state.NewPeers(), - blacklist: make([]string, 0), + blacklist: state.NewThreadSafeSlice[string](), } return &hub, nil