From 2f2ce6afaa67d6d013bf7d785ed88557e8a9cc21 Mon Sep 17 00:00:00 2001 From: Bret <787344+bretep@users.noreply.github.com> Date: Mon, 10 Jun 2024 18:56:33 -0700 Subject: [PATCH] Refactor logsFilter to prevent concurrent map fatal errors (#10672) #### Issue: At line 129 in `logsfilter.go`, we had the following line of code: ```go _, addrOk := filter.addrs[gointerfaces.ConvertH160toAddress(eventLog.Address)] ``` This line caused a panic due to a fatal error: ```logs fatal error: concurrent map read and map write goroutine 106 [running]: github.com/ledgerwatch/erigon/turbo/rpchelper.(*LogsFilterAggregator).distributeLog.func1({0xc009701db8?, 0x8?}, 0xc135d26050) github.com/ledgerwatch/erigon/turbo/rpchelper/logsfilter.go:129 +0xe7 github.com/ledgerwatch/erigon/turbo/rpchelper.(*SyncMap[...]).Range(0xc009701eb0?, 0xc009701e70?) github.com/ledgerwatch/erigon/turbo/rpchelper/subscription.go:97 +0x11a github.com/ledgerwatch/erigon/turbo/rpchelper.(*LogsFilterAggregator).distributeLog(0x25f4600?, 0xc0000ce090?) github.com/ledgerwatch/erigon/turbo/rpchelper/logsfilter.go:131 +0xc7 github.com/ledgerwatch/erigon/turbo/rpchelper.(*Filters).OnNewLogs(...) github.com/ledgerwatch/erigon/turbo/rpchelper/filters.go:547 github.com/ledgerwatch/erigon/cmd/rpcdaemon/rpcservices.(*RemoteBackend).SubscribeLogs(0xc0019c2f50, {0x32f0040, 0xc001b4a280}, 0xc001c0c0e0, 0x0?) github.com/ledgerwatch/erigon/cmd/rpcdaemon/rpcservices/eth_backend.go:227 +0x1d1 github.com/ledgerwatch/erigon/turbo/rpchelper.New.func2() github.com/ledgerwatch/erigon/turbo/rpchelper/filters.go:102 +0xec created by github.com/ledgerwatch/erigon/turbo/rpchelper.New github.com/ledgerwatch/erigon/turbo/rpchelper/filters.go:92 +0x652 ``` This error indicates that there were simultaneous read and write operations on the `filter.addrs` map, leading to a race condition. #### Solution: To resolve this issue, I implemented the following changes: - Moved SyncMap to erigon-lib common library: This allows us to utilize a thread-safe map across different packages that require synchronized map access. - Refactored logsFilter to use SyncMap: By replacing the standard map with SyncMap, we ensured that all map operations are thread-safe, thus preventing concurrent read and write errors. - Added documentation for SyncMap usage: Detailed documentation was provided to guide the usage of SyncMap and related refactored components, ensuring clarity and proper utilization. --- erigon-lib/common/concurrent/concurrent.go | 79 ++++++++++++ turbo/rpchelper/filters.go | 120 ++++++++++++++----- turbo/rpchelper/filters_test.go | 20 ++-- turbo/rpchelper/logsfilter.go | 133 +++++++++++++++------ turbo/rpchelper/subscription.go | 66 ---------- 5 files changed, 276 insertions(+), 142 deletions(-) create mode 100644 erigon-lib/common/concurrent/concurrent.go diff --git a/erigon-lib/common/concurrent/concurrent.go b/erigon-lib/common/concurrent/concurrent.go new file mode 100644 index 00000000000..a29301ea79b --- /dev/null +++ b/erigon-lib/common/concurrent/concurrent.go @@ -0,0 +1,79 @@ +package concurrent + +import "sync" + +// NewSyncMap initializes and returns a new instance of SyncMap. +func NewSyncMap[K comparable, T any]() *SyncMap[K, T] { + return &SyncMap[K, T]{ + m: make(map[K]T), + } +} + +// SyncMap is a generic map that uses a read-write mutex to ensure thread-safe access. +type SyncMap[K comparable, T any] struct { + m map[K]T + mu sync.RWMutex +} + +// Get retrieves the value associated with the given key. 
+func (m *SyncMap[K, T]) Get(k K) (res T, ok bool) { + m.mu.RLock() + defer m.mu.RUnlock() + res, ok = m.m[k] + return res, ok +} + +// Put sets the value for the given key, returning the previous value if present. +func (m *SyncMap[K, T]) Put(k K, v T) (T, bool) { + m.mu.Lock() + defer m.mu.Unlock() + old, ok := m.m[k] + m.m[k] = v + return old, ok +} + +// Do performs a custom operation on the value associated with the given key. +func (m *SyncMap[K, T]) Do(k K, fn func(T, bool) (T, bool)) (after T, ok bool) { + m.mu.Lock() + defer m.mu.Unlock() + val, ok := m.m[k] + nv, save := fn(val, ok) + if save { + m.m[k] = nv + } else { + delete(m.m, k) + } + return nv, ok +} + +// DoAndStore performs a custom operation on the value associated with the given key and stores the result. +func (m *SyncMap[K, T]) DoAndStore(k K, fn func(t T, ok bool) T) (after T, ok bool) { + return m.Do(k, func(t T, b bool) (T, bool) { + res := fn(t, b) + return res, true + }) +} + +// Range calls a function for each key-value pair in the map. +func (m *SyncMap[K, T]) Range(fn func(k K, v T) error) error { + m.mu.RLock() + defer m.mu.RUnlock() + for k, v := range m.m { + if err := fn(k, v); err != nil { + return err + } + } + return nil +} + +// Delete removes the value associated with the given key, if present. +func (m *SyncMap[K, T]) Delete(k K) (t T, deleted bool) { + m.mu.Lock() + defer m.mu.Unlock() + val, ok := m.m[k] + if !ok { + return t, false + } + delete(m.m, k) + return val, true +} diff --git a/turbo/rpchelper/filters.go b/turbo/rpchelper/filters.go index 1d99858cd15..cf8a80eaf0d 100644 --- a/turbo/rpchelper/filters.go +++ b/turbo/rpchelper/filters.go @@ -13,6 +13,7 @@ import ( "time" libcommon "github.com/ledgerwatch/erigon-lib/common" + "github.com/ledgerwatch/erigon-lib/common/concurrent" "github.com/ledgerwatch/erigon-lib/gointerfaces" "github.com/ledgerwatch/erigon-lib/gointerfaces/grpcutil" remote "github.com/ledgerwatch/erigon-lib/gointerfaces/remoteproto" @@ -26,39 +27,45 @@ import ( "github.com/ledgerwatch/erigon/rlp" ) +// Filters holds the state for managing subscriptions to various Ethereum events. +// It allows for the subscription and management of events such as new blocks, pending transactions, +// logs, and other Ethereum-related activities. type Filters struct { mu sync.RWMutex pendingBlock *types.Block - headsSubs *SyncMap[HeadsSubID, Sub[*types.Header]] - pendingLogsSubs *SyncMap[PendingLogsSubID, Sub[types.Logs]] - pendingBlockSubs *SyncMap[PendingBlockSubID, Sub[*types.Block]] - pendingTxsSubs *SyncMap[PendingTxsSubID, Sub[[]types.Transaction]] + headsSubs *concurrent.SyncMap[HeadsSubID, Sub[*types.Header]] + pendingLogsSubs *concurrent.SyncMap[PendingLogsSubID, Sub[types.Logs]] + pendingBlockSubs *concurrent.SyncMap[PendingBlockSubID, Sub[*types.Block]] + pendingTxsSubs *concurrent.SyncMap[PendingTxsSubID, Sub[[]types.Transaction]] logsSubs *LogsFilterAggregator logsRequestor atomic.Value onNewSnapshot func() storeMu sync.Mutex - logsStores *SyncMap[LogsSubID, []*types.Log] - pendingHeadsStores *SyncMap[HeadsSubID, []*types.Header] - pendingTxsStores *SyncMap[PendingTxsSubID, [][]types.Transaction] + logsStores *concurrent.SyncMap[LogsSubID, []*types.Log] + pendingHeadsStores *concurrent.SyncMap[HeadsSubID, []*types.Header] + pendingTxsStores *concurrent.SyncMap[PendingTxsSubID, [][]types.Transaction] logger log.Logger } +// New creates a new Filters instance, initializes it, and starts subscription goroutines for Ethereum events. 
+// It requires a context, Ethereum backend, transaction pool client, mining client, snapshot callback function, +// and a logger for logging events. func New(ctx context.Context, ethBackend ApiBackend, txPool txpool.TxpoolClient, mining txpool.MiningClient, onNewSnapshot func(), logger log.Logger) *Filters { logger.Info("rpc filters: subscribing to Erigon events") ff := &Filters{ - headsSubs: NewSyncMap[HeadsSubID, Sub[*types.Header]](), - pendingTxsSubs: NewSyncMap[PendingTxsSubID, Sub[[]types.Transaction]](), - pendingLogsSubs: NewSyncMap[PendingLogsSubID, Sub[types.Logs]](), - pendingBlockSubs: NewSyncMap[PendingBlockSubID, Sub[*types.Block]](), + headsSubs: concurrent.NewSyncMap[HeadsSubID, Sub[*types.Header]](), + pendingTxsSubs: concurrent.NewSyncMap[PendingTxsSubID, Sub[[]types.Transaction]](), + pendingLogsSubs: concurrent.NewSyncMap[PendingLogsSubID, Sub[types.Logs]](), + pendingBlockSubs: concurrent.NewSyncMap[PendingBlockSubID, Sub[*types.Block]](), logsSubs: NewLogsFilterAggregator(), onNewSnapshot: onNewSnapshot, - logsStores: NewSyncMap[LogsSubID, []*types.Log](), - pendingHeadsStores: NewSyncMap[HeadsSubID, []*types.Header](), - pendingTxsStores: NewSyncMap[PendingTxsSubID, [][]types.Transaction](), + logsStores: concurrent.NewSyncMap[LogsSubID, []*types.Log](), + pendingHeadsStores: concurrent.NewSyncMap[HeadsSubID, []*types.Header](), + pendingTxsStores: concurrent.NewSyncMap[PendingTxsSubID, [][]types.Transaction](), logger: logger, } @@ -185,12 +192,15 @@ func New(ctx context.Context, ethBackend ApiBackend, txPool txpool.TxpoolClient, return ff } +// LastPendingBlock returns the last pending block that was received. func (ff *Filters) LastPendingBlock() *types.Block { ff.mu.RLock() defer ff.mu.RUnlock() return ff.pendingBlock } +// subscribeToPendingTransactions subscribes to pending transactions using the given transaction pool client. +// It listens for new transactions and processes them as they arrive. func (ff *Filters) subscribeToPendingTransactions(ctx context.Context, txPool txpool.TxpoolClient) error { subscription, err := txPool.OnAdd(ctx, &txpool.OnAddRequest{}, grpc.WaitForReady(true)) if err != nil { @@ -211,6 +221,8 @@ func (ff *Filters) subscribeToPendingTransactions(ctx context.Context, txPool tx return nil } +// subscribeToPendingBlocks subscribes to pending blocks using the given mining client. +// It listens for new pending blocks and processes them as they arrive. func (ff *Filters) subscribeToPendingBlocks(ctx context.Context, mining txpool.MiningClient) error { subscription, err := mining.OnPendingBlock(ctx, &txpool.OnPendingBlockRequest{}, grpc.WaitForReady(true)) if err != nil { @@ -237,6 +249,8 @@ func (ff *Filters) subscribeToPendingBlocks(ctx context.Context, mining txpool.M return nil } +// HandlePendingBlock handles a new pending block received from the mining client. +// It updates the internal state and notifies subscribers about the new block. func (ff *Filters) HandlePendingBlock(reply *txpool.OnPendingBlockReply) { b := &types.Block{} if reply == nil || len(reply.RplBlock) == 0 { @@ -256,6 +270,8 @@ func (ff *Filters) HandlePendingBlock(reply *txpool.OnPendingBlockReply) { }) } +// subscribeToPendingLogs subscribes to pending logs using the given mining client. +// It listens for new pending logs and processes them as they arrive. 
func (ff *Filters) subscribeToPendingLogs(ctx context.Context, mining txpool.MiningClient) error { subscription, err := mining.OnPendingLogs(ctx, &txpool.OnPendingLogsRequest{}, grpc.WaitForReady(true)) if err != nil { @@ -281,6 +297,8 @@ func (ff *Filters) subscribeToPendingLogs(ctx context.Context, mining txpool.Min return nil } +// HandlePendingLogs handles new pending logs received from the mining client. +// It updates the internal state and notifies subscribers about the new logs. func (ff *Filters) HandlePendingLogs(reply *txpool.OnPendingLogsReply) { if len(reply.RplLogs) == 0 { return @@ -295,6 +313,8 @@ func (ff *Filters) HandlePendingLogs(reply *txpool.OnPendingLogsReply) { }) } +// SubscribeNewHeads subscribes to new block headers and returns a channel to receive the headers +// and a subscription ID to manage the subscription. func (ff *Filters) SubscribeNewHeads(size int) (<-chan *types.Header, HeadsSubID) { id := HeadsSubID(generateSubscriptionID()) sub := newChanSub[*types.Header](size) @@ -302,6 +322,8 @@ func (ff *Filters) SubscribeNewHeads(size int) (<-chan *types.Header, HeadsSubID return sub.ch, id } +// UnsubscribeHeads unsubscribes from new block headers using the given subscription ID. +// It returns true if the unsubscription was successful, otherwise false. func (ff *Filters) UnsubscribeHeads(id HeadsSubID) bool { ch, ok := ff.headsSubs.Get(id) if !ok { @@ -315,6 +337,8 @@ func (ff *Filters) UnsubscribeHeads(id HeadsSubID) bool { return true } +// SubscribePendingLogs subscribes to pending logs and returns a channel to receive the logs +// and a subscription ID to manage the subscription. It uses the specified filter criteria. func (ff *Filters) SubscribePendingLogs(size int) (<-chan types.Logs, PendingLogsSubID) { id := PendingLogsSubID(generateSubscriptionID()) sub := newChanSub[types.Logs](size) @@ -322,6 +346,7 @@ func (ff *Filters) SubscribePendingLogs(size int) (<-chan types.Logs, PendingLog return sub.ch, id } +// UnsubscribePendingLogs unsubscribes from pending logs using the given subscription ID. func (ff *Filters) UnsubscribePendingLogs(id PendingLogsSubID) { ch, ok := ff.pendingLogsSubs.Get(id) if !ok { @@ -331,6 +356,8 @@ func (ff *Filters) UnsubscribePendingLogs(id PendingLogsSubID) { ff.pendingLogsSubs.Delete(id) } +// SubscribePendingBlock subscribes to pending blocks and returns a channel to receive the blocks +// and a subscription ID to manage the subscription. func (ff *Filters) SubscribePendingBlock(size int) (<-chan *types.Block, PendingBlockSubID) { id := PendingBlockSubID(generateSubscriptionID()) sub := newChanSub[*types.Block](size) @@ -338,6 +365,7 @@ func (ff *Filters) SubscribePendingBlock(size int) (<-chan *types.Block, Pending return sub.ch, id } +// UnsubscribePendingBlock unsubscribes from pending blocks using the given subscription ID. func (ff *Filters) UnsubscribePendingBlock(id PendingBlockSubID) { ch, ok := ff.pendingBlockSubs.Get(id) if !ok { @@ -347,6 +375,8 @@ func (ff *Filters) UnsubscribePendingBlock(id PendingBlockSubID) { ff.pendingBlockSubs.Delete(id) } +// SubscribePendingTxs subscribes to pending transactions and returns a channel to receive the transactions +// and a subscription ID to manage the subscription. 
func (ff *Filters) SubscribePendingTxs(size int) (<-chan []types.Transaction, PendingTxsSubID) { id := PendingTxsSubID(generateSubscriptionID()) sub := newChanSub[[]types.Transaction](size) @@ -354,6 +384,8 @@ func (ff *Filters) SubscribePendingTxs(size int) (<-chan []types.Transaction, Pe return sub.ch, id } +// UnsubscribePendingTxs unsubscribes from pending transactions using the given subscription ID. +// It returns true if the unsubscription was successful, otherwise false. func (ff *Filters) UnsubscribePendingTxs(id PendingTxsSubID) bool { ch, ok := ff.pendingTxsSubs.Get(id) if !ok { @@ -367,31 +399,45 @@ func (ff *Filters) UnsubscribePendingTxs(id PendingTxsSubID) bool { return true } -func (ff *Filters) SubscribeLogs(size int, crit filters.FilterCriteria) (<-chan *types.Log, LogsSubID) { +// SubscribeLogs subscribes to logs using the specified filter criteria and returns a channel to receive the logs +// and a subscription ID to manage the subscription. +func (ff *Filters) SubscribeLogs(size int, criteria filters.FilterCriteria) (<-chan *types.Log, LogsSubID) { sub := newChanSub[*types.Log](size) id, f := ff.logsSubs.insertLogsFilter(sub) - f.addrs = map[libcommon.Address]int{} - if len(crit.Addresses) == 0 { + + // Initialize address and topic maps + f.addrs = concurrent.NewSyncMap[libcommon.Address, int]() + f.topics = concurrent.NewSyncMap[libcommon.Hash, int]() + + // Handle addresses + if len(criteria.Addresses) == 0 { + // If no addresses are specified, it means all addresses should be included f.allAddrs = 1 } else { - for _, addr := range crit.Addresses { - f.addrs[addr] = 1 + for _, addr := range criteria.Addresses { + f.addrs.Put(addr, 1) } } - f.topics = map[libcommon.Hash]int{} - if len(crit.Topics) == 0 { + + // Handle topics + if len(criteria.Topics) == 0 { + // If no topics are specified, it means all topics should be included f.allTopics = 1 } else { - for _, topics := range crit.Topics { + for _, topics := range criteria.Topics { for _, topic := range topics { - f.topics[topic] = 1 + f.topics.Put(topic, 1) } } } - f.topicsOriginal = crit.Topics + + // Store original topics for reference + f.topicsOriginal = criteria.Topics + + // Add the filter to the list of log filters ff.logsSubs.addLogsFilters(f) - // if any filter in the aggregate needs all addresses or all topics then the global log subscription needs to - // allow all addresses or topics through + + // Create a filter request based on the aggregated filters lfr := ff.logsSubs.createFilterRequest() addresses, topics := ff.logsSubs.getAggMaps() for addr := range addresses { @@ -412,12 +458,15 @@ func (ff *Filters) SubscribeLogs(size int, crit filters.FilterCriteria) (<-chan return sub.ch, id } +// loadLogsRequester loads the current logs requester and returns it. func (ff *Filters) loadLogsRequester() any { ff.mu.Lock() defer ff.mu.Unlock() return ff.logsRequestor.Load() } +// UnsubscribeLogs unsubscribes from logs using the given subscription ID. +// It returns true if the unsubscription was successful, otherwise false. func (ff *Filters) UnsubscribeLogs(id LogsSubID) bool { isDeleted := ff.logsSubs.removeLogsFilter(id) // if any filters in the aggregate need all addresses or all topics then the request to the central @@ -445,11 +494,12 @@ func (ff *Filters) UnsubscribeLogs(id LogsSubID) bool { return isDeleted } +// deleteLogStore deletes the log store associated with the given subscription ID. 
func (ff *Filters) deleteLogStore(id LogsSubID) { ff.logsStores.Delete(id) } -// OnNewEvent is called when there is a new Event from the remote +// OnNewEvent is called when there is a new event from the remote and processes it. func (ff *Filters) OnNewEvent(event *remote.SubscribeReply) { err := ff.onNewEvent(event) if err != nil { @@ -457,6 +507,7 @@ func (ff *Filters) OnNewEvent(event *remote.SubscribeReply) { } } +// onNewEvent processes the given event from the remote and updates the internal state. func (ff *Filters) onNewEvent(event *remote.SubscribeReply) error { switch event.Type { case remote.Event_HEADER: @@ -474,6 +525,7 @@ func (ff *Filters) onNewEvent(event *remote.SubscribeReply) error { } // TODO: implement? +// onPendingLog handles a new pending log event from the remote. func (ff *Filters) onPendingLog(event *remote.SubscribeReply) error { // payload := event.Data // var logs types.Logs @@ -490,6 +542,7 @@ func (ff *Filters) onPendingLog(event *remote.SubscribeReply) error { } // TODO: implement? +// onPendingBlock handles a new pending block event from the remote. func (ff *Filters) onPendingBlock(event *remote.SubscribeReply) error { // payload := event.Data // var block types.Block @@ -505,6 +558,7 @@ func (ff *Filters) onPendingBlock(event *remote.SubscribeReply) error { return nil } +// onNewHeader handles a new block header event from the remote and updates the internal state. func (ff *Filters) onNewHeader(event *remote.SubscribeReply) error { payload := event.Data var header types.Header @@ -521,6 +575,7 @@ func (ff *Filters) onNewHeader(event *remote.SubscribeReply) error { }) } +// OnNewTx handles a new transaction event from the transaction pool and processes it. func (ff *Filters) OnNewTx(reply *txpool.OnAddReply) { txs := make([]types.Transaction, len(reply.RplTxs)) for i, rlpTx := range reply.RplTxs { @@ -541,11 +596,12 @@ func (ff *Filters) OnNewTx(reply *txpool.OnAddReply) { }) } -// OnNewLogs is called when there is a new log +// OnNewLogs handles a new log event from the remote and processes it. func (ff *Filters) OnNewLogs(reply *remote.SubscribeLogsReply) { ff.logsSubs.distributeLog(reply) } +// AddLogs adds logs to the store associated with the given subscription ID. func (ff *Filters) AddLogs(id LogsSubID, logs *types.Log) { ff.logsStores.DoAndStore(id, func(st []*types.Log, ok bool) []*types.Log { if !ok { @@ -556,6 +612,8 @@ func (ff *Filters) AddLogs(id LogsSubID, logs *types.Log) { }) } +// ReadLogs reads logs from the store associated with the given subscription ID. +// It returns the logs and a boolean indicating whether the logs were found. func (ff *Filters) ReadLogs(id LogsSubID) ([]*types.Log, bool) { res, ok := ff.logsStores.Delete(id) if !ok { @@ -564,6 +622,7 @@ func (ff *Filters) ReadLogs(id LogsSubID) ([]*types.Log, bool) { return res, true } +// AddPendingBlock adds a pending block header to the store associated with the given subscription ID. func (ff *Filters) AddPendingBlock(id HeadsSubID, block *types.Header) { ff.pendingHeadsStores.DoAndStore(id, func(st []*types.Header, ok bool) []*types.Header { if !ok { @@ -574,6 +633,8 @@ func (ff *Filters) AddPendingBlock(id HeadsSubID, block *types.Header) { }) } +// ReadPendingBlocks reads pending block headers from the store associated with the given subscription ID. +// It returns the block headers and a boolean indicating whether the headers were found. 
func (ff *Filters) ReadPendingBlocks(id HeadsSubID) ([]*types.Header, bool) { res, ok := ff.pendingHeadsStores.Delete(id) if !ok { @@ -582,6 +643,7 @@ func (ff *Filters) ReadPendingBlocks(id HeadsSubID) ([]*types.Header, bool) { return res, true } +// AddPendingTxs adds pending transactions to the store associated with the given subscription ID. func (ff *Filters) AddPendingTxs(id PendingTxsSubID, txs []types.Transaction) { ff.pendingTxsStores.DoAndStore(id, func(st [][]types.Transaction, ok bool) [][]types.Transaction { if !ok { @@ -592,6 +654,8 @@ func (ff *Filters) AddPendingTxs(id PendingTxsSubID, txs []types.Transaction) { }) } +// ReadPendingTxs reads pending transactions from the store associated with the given subscription ID. +// It returns the transactions and a boolean indicating whether the transactions were found. func (ff *Filters) ReadPendingTxs(id PendingTxsSubID) ([][]types.Transaction, bool) { res, ok := ff.pendingTxsStores.Delete(id) if !ok { diff --git a/turbo/rpchelper/filters_test.go b/turbo/rpchelper/filters_test.go index 5f4e10b1d28..b66a36e2901 100644 --- a/turbo/rpchelper/filters_test.go +++ b/turbo/rpchelper/filters_test.go @@ -270,7 +270,7 @@ func TestFilters_SubscribeLogsGeneratesCorrectLogFilterRequest(t *testing.T) { if lastFilterRequest.AllTopics == false { t.Error("2: expected all topics to be true") } - if len(lastFilterRequest.Addresses) != 1 && lastFilterRequest.Addresses[0] != address1H160 { + if len(lastFilterRequest.Addresses) != 1 && gointerfaces.ConvertH160toAddress(lastFilterRequest.Addresses[0]) != gointerfaces.ConvertH160toAddress(address1H160) { t.Error("2: expected the address to match the last request") } @@ -288,10 +288,10 @@ func TestFilters_SubscribeLogsGeneratesCorrectLogFilterRequest(t *testing.T) { if lastFilterRequest.AllTopics == false { t.Error("3: expected all topics to be true") } - if len(lastFilterRequest.Addresses) != 1 && lastFilterRequest.Addresses[0] != address1H160 { + if len(lastFilterRequest.Addresses) != 1 && gointerfaces.ConvertH160toAddress(lastFilterRequest.Addresses[0]) != gointerfaces.ConvertH160toAddress(address1H160) { t.Error("3: expected the address to match the previous request") } - if len(lastFilterRequest.Topics) != 1 && lastFilterRequest.Topics[0] != topic1H256 { + if len(lastFilterRequest.Topics) != 1 && gointerfaces.ConvertH256ToHash(lastFilterRequest.Topics[0]) != gointerfaces.ConvertH256ToHash(topic1H256) { t.Error("3: expected the topics to match the last request") } @@ -307,10 +307,10 @@ func TestFilters_SubscribeLogsGeneratesCorrectLogFilterRequest(t *testing.T) { if lastFilterRequest.AllTopics == false { t.Error("4: expected all topics to be true") } - if len(lastFilterRequest.Addresses) != 1 && lastFilterRequest.Addresses[0] != address1H160 { + if len(lastFilterRequest.Addresses) != 1 && gointerfaces.ConvertH160toAddress(lastFilterRequest.Addresses[0]) != gointerfaces.ConvertH160toAddress(address1H160) { t.Error("4: expected an address to be present") } - if len(lastFilterRequest.Topics) != 1 && lastFilterRequest.Topics[0] != topic1H256 { + if len(lastFilterRequest.Topics) != 1 && gointerfaces.ConvertH256ToHash(lastFilterRequest.Topics[0]) != gointerfaces.ConvertH256ToHash(topic1H256) { t.Error("4: expected a topic to be present") } @@ -327,7 +327,7 @@ func TestFilters_SubscribeLogsGeneratesCorrectLogFilterRequest(t *testing.T) { if len(lastFilterRequest.Addresses) != 0 { t.Error("5: expected addresses to be empty") } - if len(lastFilterRequest.Topics) != 1 && lastFilterRequest.Topics[0] != 
topic1H256 { + if len(lastFilterRequest.Topics) != 1 && gointerfaces.ConvertH256ToHash(lastFilterRequest.Topics[0]) != gointerfaces.ConvertH256ToHash(topic1H256) { t.Error("5: expected a topic to be present") } @@ -335,15 +335,15 @@ func TestFilters_SubscribeLogsGeneratesCorrectLogFilterRequest(t *testing.T) { // and nothing in the address or topics lists f.UnsubscribeLogs(id3) if lastFilterRequest.AllAddresses == true { - t.Error("5: expected all addresses to be false") + t.Error("6: expected all addresses to be false") } if lastFilterRequest.AllTopics == true { - t.Error("5: expected all topics to be false") + t.Error("6: expected all topics to be false") } if len(lastFilterRequest.Addresses) != 0 { - t.Error("5: expected addresses to be empty") + t.Error("6: expected addresses to be empty") } if len(lastFilterRequest.Topics) != 0 { - t.Error("5: expected topics to be empty") + t.Error("6: expected topics to be empty") } } diff --git a/turbo/rpchelper/logsfilter.go b/turbo/rpchelper/logsfilter.go index f7d598d670b..07321609627 100644 --- a/turbo/rpchelper/logsfilter.go +++ b/turbo/rpchelper/logsfilter.go @@ -4,6 +4,7 @@ import ( "sync" libcommon "github.com/ledgerwatch/erigon-lib/common" + "github.com/ledgerwatch/erigon-lib/common/concurrent" "github.com/ledgerwatch/erigon-lib/gointerfaces" remote "github.com/ledgerwatch/erigon-lib/gointerfaces/remoteproto" @@ -11,56 +12,76 @@ import ( ) type LogsFilterAggregator struct { - aggLogsFilter LogsFilter // Aggregation of all current log filters - logsFilters *SyncMap[LogsSubID, *LogsFilter] // Filter for each subscriber, keyed by filterID + aggLogsFilter LogsFilter // Aggregation of all current log filters + logsFilters *concurrent.SyncMap[LogsSubID, *LogsFilter] // Filter for each subscriber, keyed by filterID logsFilterLock sync.RWMutex } // LogsFilter is used for both representing log filter for a specific subscriber (RPC daemon usually) // and "aggregated" log filter representing a union of all subscribers. Therefore, the values in -// the mappings are counters (of type int) and they get deleted when counter goes back to 0 -// Also, addAddr and allTopic are int instead of bool because they are also counter, counting -// how many subscribers have this set on +// the mappings are counters (of type int) and they get deleted when counter goes back to 0. +// Also, addAddr and allTopic are int instead of bool because they are also counters, counting +// how many subscribers have this set on. type LogsFilter struct { allAddrs int - addrs map[libcommon.Address]int + addrs *concurrent.SyncMap[libcommon.Address, int] allTopics int - topics map[libcommon.Hash]int + topics *concurrent.SyncMap[libcommon.Hash, int] topicsOriginal [][]libcommon.Hash // Original topic filters to be applied before distributing to individual subscribers sender Sub[*types2.Log] // nil for aggregate subscriber, for appropriate stream server otherwise } +// Send sends a log to the subscriber represented by the LogsFilter. +// It forwards the log to the subscriber's sender. func (l *LogsFilter) Send(lg *types2.Log) { l.sender.Send(lg) } + +// Close closes the sender associated with the LogsFilter. +// It is used to properly clean up and release resources associated with the sender. func (l *LogsFilter) Close() { l.sender.Close() } +// NewLogsFilterAggregator creates and returns a new instance of LogsFilterAggregator. +// It initializes the aggregated log filter and the map of individual log filters. 
func NewLogsFilterAggregator() *LogsFilterAggregator { return &LogsFilterAggregator{ aggLogsFilter: LogsFilter{ - addrs: make(map[libcommon.Address]int), - topics: make(map[libcommon.Hash]int), + addrs: concurrent.NewSyncMap[libcommon.Address, int](), + topics: concurrent.NewSyncMap[libcommon.Hash, int](), }, - logsFilters: NewSyncMap[LogsSubID, *LogsFilter](), + logsFilters: concurrent.NewSyncMap[LogsSubID, *LogsFilter](), } } +// insertLogsFilter inserts a new log filter into the LogsFilterAggregator with the specified sender. +// It generates a new filter ID, creates a new LogsFilter, and adds it to the logsFilters map. func (a *LogsFilterAggregator) insertLogsFilter(sender Sub[*types2.Log]) (LogsSubID, *LogsFilter) { + a.logsFilterLock.Lock() + defer a.logsFilterLock.Unlock() filterId := LogsSubID(generateSubscriptionID()) - filter := &LogsFilter{addrs: map[libcommon.Address]int{}, topics: map[libcommon.Hash]int{}, sender: sender} + filter := &LogsFilter{ + addrs: concurrent.NewSyncMap[libcommon.Address, int](), + topics: concurrent.NewSyncMap[libcommon.Hash, int](), + sender: sender, + } a.logsFilters.Put(filterId, filter) return filterId, filter } +// removeLogsFilter removes a log filter identified by filterId from the LogsFilterAggregator. +// It closes the filter and subtracts its addresses and topics from the aggregated filter. func (a *LogsFilterAggregator) removeLogsFilter(filterId LogsSubID) bool { + a.logsFilterLock.Lock() + defer a.logsFilterLock.Unlock() + filter, ok := a.logsFilters.Get(filterId) if !ok { return false } filter.Close() - filter, ok = a.logsFilters.Delete(filterId) + _, ok = a.logsFilters.Delete(filterId) if !ok { return false } @@ -68,6 +89,8 @@ func (a *LogsFilterAggregator) removeLogsFilter(filterId LogsSubID) bool { return true } +// createFilterRequest creates a LogsFilterRequest from the current state of the LogsFilterAggregator. +// It generates a request that represents the union of all current log filters. func (a *LogsFilterAggregator) createFilterRequest() *remote.LogsFilterRequest { a.logsFilterLock.RLock() defer a.logsFilterLock.RUnlock() @@ -77,56 +100,88 @@ func (a *LogsFilterAggregator) createFilterRequest() *remote.LogsFilterRequest { } } +// subtractLogFilters subtracts the counts of addresses and topics in the given LogsFilter from the aggregated filter. +// It decrements the counters for each address and topic in the aggregated filter by the corresponding counts in the +// provided LogsFilter. If the count for any address or topic reaches zero, it is removed from the aggregated filter. 
func (a *LogsFilterAggregator) subtractLogFilters(f *LogsFilter) { - a.logsFilterLock.Lock() - defer a.logsFilterLock.Unlock() a.aggLogsFilter.allAddrs -= f.allAddrs - for addr, count := range f.addrs { - a.aggLogsFilter.addrs[addr] -= count - if a.aggLogsFilter.addrs[addr] == 0 { - delete(a.aggLogsFilter.addrs, addr) - } - } + f.addrs.Range(func(addr libcommon.Address, count int) error { + a.aggLogsFilter.addrs.Do(addr, func(value int, exists bool) (int, bool) { + if exists { + newValue := value - count + if newValue <= 0 { + return 0, false + } + return newValue, true + } + return 0, false + }) + return nil + }) a.aggLogsFilter.allTopics -= f.allTopics - for topic, count := range f.topics { - a.aggLogsFilter.topics[topic] -= count - if a.aggLogsFilter.topics[topic] == 0 { - delete(a.aggLogsFilter.topics, topic) - } - } + f.topics.Range(func(topic libcommon.Hash, count int) error { + a.aggLogsFilter.topics.Do(topic, func(value int, exists bool) (int, bool) { + if exists { + newValue := value - count + if newValue <= 0 { + return 0, false + } + return newValue, true + } + return 0, false + }) + return nil + }) } +// addLogsFilters adds the counts of addresses and topics in the given LogsFilter to the aggregated filter. +// It increments the counters for each address and topic in the aggregated filter by the corresponding counts in the +// provided LogsFilter. func (a *LogsFilterAggregator) addLogsFilters(f *LogsFilter) { a.logsFilterLock.Lock() defer a.logsFilterLock.Unlock() a.aggLogsFilter.allAddrs += f.allAddrs - for addr, count := range f.addrs { - a.aggLogsFilter.addrs[addr] += count - } + f.addrs.Range(func(addr libcommon.Address, count int) error { + a.aggLogsFilter.addrs.DoAndStore(addr, func(value int, exists bool) int { + return value + count + }) + return nil + }) a.aggLogsFilter.allTopics += f.allTopics - for topic, count := range f.topics { - a.aggLogsFilter.topics[topic] += count - } + f.topics.Range(func(topic libcommon.Hash, count int) error { + a.aggLogsFilter.topics.DoAndStore(topic, func(value int, exists bool) int { + return value + count + }) + return nil + }) } +// getAggMaps returns the aggregated maps of addresses and topics from the LogsFilterAggregator. +// It creates copies of the current state of the aggregated addresses and topics filters. func (a *LogsFilterAggregator) getAggMaps() (map[libcommon.Address]int, map[libcommon.Hash]int) { a.logsFilterLock.RLock() defer a.logsFilterLock.RUnlock() addresses := make(map[libcommon.Address]int) - for k, v := range a.aggLogsFilter.addrs { + a.aggLogsFilter.addrs.Range(func(k libcommon.Address, v int) error { addresses[k] = v - } + return nil + }) topics := make(map[libcommon.Hash]int) - for k, v := range a.aggLogsFilter.topics { + a.aggLogsFilter.topics.Range(func(k libcommon.Hash, v int) error { topics[k] = v - } + return nil + }) return addresses, topics } +// distributeLog processes an event log and distributes it to all subscribed log filters. +// It checks each filter to determine if the log should be sent based on the filter's address and topic settings. 
func (a *LogsFilterAggregator) distributeLog(eventLog *remote.SubscribeLogsReply) error { + a.logsFilterLock.RLock() + defer a.logsFilterLock.RUnlock() a.logsFilters.Range(func(k LogsSubID, filter *LogsFilter) error { if filter.allAddrs == 0 { - _, addrOk := filter.addrs[gointerfaces.ConvertH160toAddress(eventLog.Address)] + _, addrOk := filter.addrs.Get(gointerfaces.ConvertH160toAddress(eventLog.Address)) if !addrOk { return nil } @@ -157,10 +212,12 @@ func (a *LogsFilterAggregator) distributeLog(eventLog *remote.SubscribeLogsReply return nil } +// chooseTopics checks if the log topics match the filter's topics. +// It returns true if the log topics match the filter's topics, otherwise false. func (a *LogsFilterAggregator) chooseTopics(filter *LogsFilter, logTopics []libcommon.Hash) bool { var found bool for _, logTopic := range logTopics { - if _, ok := filter.topics[logTopic]; ok { + if _, ok := filter.topics.Get(logTopic); ok { found = true break } diff --git a/turbo/rpchelper/subscription.go b/turbo/rpchelper/subscription.go index 6fb57b151d0..e86e46f52de 100644 --- a/turbo/rpchelper/subscription.go +++ b/turbo/rpchelper/subscription.go @@ -45,69 +45,3 @@ func (s *chan_sub[T]) Close() { s.closed = true close(s.ch) } - -func NewSyncMap[K comparable, T any]() *SyncMap[K, T] { - return &SyncMap[K, T]{ - m: make(map[K]T), - } -} - -type SyncMap[K comparable, T any] struct { - m map[K]T - mu sync.RWMutex -} - -func (m *SyncMap[K, T]) Get(k K) (res T, ok bool) { - m.mu.RLock() - defer m.mu.RUnlock() - res, ok = m.m[k] - return res, ok -} - -func (m *SyncMap[K, T]) Put(k K, v T) (T, bool) { - m.mu.Lock() - defer m.mu.Unlock() - old, ok := m.m[k] - m.m[k] = v - return old, ok -} - -func (m *SyncMap[K, T]) Do(k K, fn func(T, bool) (T, bool)) (after T, ok bool) { - m.mu.Lock() - defer m.mu.Unlock() - val, ok := m.m[k] - nv, save := fn(val, ok) - if save { - m.m[k] = nv - } - return nv, ok -} - -func (m *SyncMap[K, T]) DoAndStore(k K, fn func(t T, ok bool) T) (after T, ok bool) { - return m.Do(k, func(t T, b bool) (T, bool) { - res := fn(t, b) - return res, true - }) -} - -func (m *SyncMap[K, T]) Range(fn func(k K, v T) error) error { - m.mu.RLock() - defer m.mu.RUnlock() - for k, v := range m.m { - if err := fn(k, v); err != nil { - return err - } - } - return nil -} - -func (m *SyncMap[K, T]) Delete(k K) (t T, deleted bool) { - m.mu.Lock() - defer m.mu.Unlock() - val, ok := m.m[k] - if !ok { - return t, false - } - delete(m.m, k) - return val, true -}
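
For context, here is a minimal sketch (not part of this diff) of the failure mode described in the Issue section: reading a plain Go map while another goroutine writes to it makes the runtime abort with `fatal error: concurrent map read and map write`, which is what happened with `filter.addrs` when `distributeLog` read it while a subscription call mutated it.

```go
package main

import "time"

// Deliberately racy demo: one goroutine writes a plain map while the main
// goroutine reads it, mimicking a subscription write racing distributeLog's read.
// The Go runtime typically detects this and aborts with
// "fatal error: concurrent map read and map write".
func main() {
	addrs := map[string]int{"0xaddr": 1} // stands in for the old filter.addrs

	go func() {
		for i := 0; ; i++ {
			addrs["0xaddr"] = i // concurrent write
		}
	}()

	for {
		_ = addrs["0xaddr"] // concurrent read
		time.Sleep(time.Microsecond)
	}
}
```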
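
And a minimal usage sketch of the new `concurrent.SyncMap`, mirroring how the refactored `logsfilter.go` keeps per-address subscriber counts. The string keys are simplified stand-ins for `libcommon.Address`; the snippet is illustrative rather than code from this change.

```go
package main

import (
	"fmt"

	"github.com/ledgerwatch/erigon-lib/common/concurrent"
)

func main() {
	// Thread-safe counter map, as used for filter.addrs / filter.topics.
	addrs := concurrent.NewSyncMap[string, int]()

	// Subscribe path: bump the reference count for an address (addLogsFilters).
	addrs.DoAndStore("0xaddr", func(count int, ok bool) int { return count + 1 })

	// Distribution path: a read that stays safe while other goroutines write,
	// replacing the bare map lookup that crashed in distributeLog.
	if count, ok := addrs.Get("0xaddr"); ok {
		fmt.Println("subscribers for 0xaddr:", count)
	}

	// Unsubscribe path: decrement and drop the key at zero (subtractLogFilters).
	addrs.Do("0xaddr", func(count int, ok bool) (int, bool) {
		if !ok || count-1 <= 0 {
			return 0, false // returning false removes the entry
		}
		return count - 1, true
	})
}
```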