diff --git a/protocol/chainlib/grpc.go b/protocol/chainlib/grpc.go
index 68367b1eca..f2425c12c7 100644
--- a/protocol/chainlib/grpc.go
+++ b/protocol/chainlib/grpc.go
@@ -9,7 +9,6 @@ import (
 	"net/http"
 	"strconv"
 	"strings"
-	"sync"
 	"time"
 
 	"github.com/goccy/go-json"
@@ -50,25 +49,6 @@ type GrpcNodeErrorResponse struct {
 	ErrorCode uint32 `json:"error_code"`
 }
 
-type grpcDescriptorCache struct {
-	cachedDescriptors sync.Map // method name is the key, method descriptor is the value
-}
-
-func (gdc *grpcDescriptorCache) getDescriptor(methodName string) *desc.MethodDescriptor {
-	if descriptor, ok := gdc.cachedDescriptors.Load(methodName); ok {
-		converted, success := descriptor.(*desc.MethodDescriptor) // convert to a descriptor
-		if success {
-			return converted
-		}
-		utils.LavaFormatError("Failed Converting method descriptor", nil, utils.Attribute{Key: "Method", Value: methodName})
-	}
-	return nil
-}
-
-func (gdc *grpcDescriptorCache) setDescriptor(methodName string, descriptor *desc.MethodDescriptor) {
-	gdc.cachedDescriptors.Store(methodName, descriptor)
-}
-
 type GrpcChainParser struct {
 	BaseChainParser
 
@@ -388,7 +368,7 @@ func (apil *GrpcChainListener) GetListeningAddress() string {
 type GrpcChainProxy struct {
 	BaseChainProxy
 	conn             grpcConnectorInterface
-	descriptorsCache *grpcDescriptorCache
+	descriptorsCache *common.SafeSyncMap[string, *desc.MethodDescriptor]
 }
 
 type grpcConnectorInterface interface {
 	Close()
@@ -413,7 +393,7 @@ func NewGrpcChainProxy(ctx context.Context, nConns uint, rpcProviderEndpoint lav
 func newGrpcChainProxy(ctx context.Context, averageBlockTime time.Duration, parser ChainParser, conn grpcConnectorInterface, rpcProviderEndpoint lavasession.RPCProviderEndpoint) (ChainProxy, error) {
 	cp := &GrpcChainProxy{
 		BaseChainProxy:   BaseChainProxy{averageBlockTime: averageBlockTime, ErrorHandler: &GRPCErrorHandler{}, ChainID: rpcProviderEndpoint.ChainID, HashedNodeUrl: chainproxy.HashURL(rpcProviderEndpoint.NodeUrls[0].Url)},
-		descriptorsCache: &grpcDescriptorCache{},
+		descriptorsCache: &common.SafeSyncMap[string, *desc.MethodDescriptor]{},
 	}
 	cp.conn = conn
 	if cp.conn == nil {
@@ -471,9 +451,12 @@ func (cp *GrpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{},
 	descriptorSource := rpcInterfaceMessages.DescriptorSourceFromServer(cl)
 	svc, methodName := rpcInterfaceMessages.ParseSymbol(nodeMessage.Path)
 
-	// check if we have method descriptor already cached.
-	methodDescriptor := cp.descriptorsCache.getDescriptor(methodName)
-	if methodDescriptor == nil { // method descriptor not cached yet, need to fetch it and add to cache
+	// Check whether the method descriptor is already cached.
+	// The reason we do Load and then Store here, instead of LoadOrStore:
+	// in the worst case, when two threads access the map at the same time, the same descriptor is fetched and stored twice.
+	// That is better than the alternative, which is always creating the descriptor up front, since the outcome is the same.
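Note on the hunk above: swapping `LoadOrStore` for a separate `Load` and later `Store` deliberately tolerates a benign race. A minimal standalone sketch of the trade-off (all names here are illustrative, not from the codebase):

```go
package main

import (
	"fmt"
	"sync"
)

// descriptorCache mimics the Load-then-Store pattern: on a cache miss the
// caller fetches/builds the value and stores it afterwards. Two concurrent
// misses may build the same value twice, but the builds are equivalent, so
// the last Store winning is harmless - and the hot path (a hit) never pays
// the construction cost that LoadOrStore would force on every call.
type descriptorCache struct {
	entries sync.Map // method name -> descriptor
}

func (c *descriptorCache) get(method string, build func() string) string {
	if v, ok := c.entries.Load(method); ok {
		return v.(string) // hit: nothing is built
	}
	d := build()               // miss: build (possibly twice under a race)
	c.entries.Store(method, d) // duplicate stores write an equivalent value
	return d
}

func main() {
	cache := &descriptorCache{}
	var wg sync.WaitGroup
	for i := 0; i < 4; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			cache.get("pkg.Service/Method", func() string { return "descriptor" })
		}()
	}
	wg.Wait()
	fmt.Println(cache.get("pkg.Service/Method", func() string { return "never built" }))
}
```

Under two concurrent misses both goroutines build an equivalent descriptor and the second `Store` overwrites the first with an identical value, which is why the comment above can say the outcome is the same.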
+	methodDescriptor, found, _ := cp.descriptorsCache.Load(methodName)
+	if !found { // method descriptor not cached yet, need to fetch it and add to cache
 		var descriptor desc.Descriptor
 		if descriptor, err = descriptorSource.FindSymbol(svc); err != nil {
 			return nil, "", nil, utils.LavaFormatError("descriptorSource.FindSymbol", err, utils.Attribute{Key: "GUID", Value: ctx})
@@ -488,7 +471,7 @@ func (cp *GrpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{},
 		}
 
 		// add the descriptor to the chainProxy cache
-		cp.descriptorsCache.setDescriptor(methodName, methodDescriptor)
+		cp.descriptorsCache.Store(methodName, methodDescriptor)
 	}
 
 	msgFactory := dynamic.NewMessageFactoryWithDefaults()
diff --git a/protocol/chaintracker/chain_tracker.go b/protocol/chaintracker/chain_tracker.go
index 9b70ba07c9..29d6d390d6 100644
--- a/protocol/chaintracker/chain_tracker.go
+++ b/protocol/chaintracker/chain_tracker.go
@@ -68,6 +68,10 @@ type ChainTracker struct {
 	blockEventsGap      []time.Duration
 	blockTimeUpdatables map[blockTimeUpdatable]struct{}
 	pmetrics            *metrics.ProviderMetricsManager
+
+	// initial config
+	averageBlockTime time.Duration
+	serverAddress    string
 }
 
 // this function returns block hashes of the blocks: [from block - to block] inclusive. an additional specific block hash can be provided. order is sorted ascending
@@ -570,6 +574,16 @@ func (ct *ChainTracker) serve(ctx context.Context, listenAddr string) error {
 	return nil
 }
 
+func (ct *ChainTracker) StartAndServe(ctx context.Context) error {
+	err := ct.start(ctx, ct.averageBlockTime)
+	if err != nil {
+		return err
+	}
+
+	err = ct.serve(ctx, ct.serverAddress)
+	return err
+}
+
 func NewChainTracker(ctx context.Context, chainFetcher ChainFetcher, config ChainTrackerConfig) (chainTracker *ChainTracker, err error) {
 	if !rand.Initialized() {
 		utils.LavaFormatFatal("can't start chainTracker with nil rand source", nil)
@@ -598,16 +612,13 @@ func NewChainTracker(ctx context.Context, chainFetcher ChainFetcher, config Chai
 		startupTime:           time.Now(),
 		pmetrics:              config.Pmetrics,
 		pollingTimeMultiplier: time.Duration(pollingTime),
+		averageBlockTime:      config.AverageBlockTime,
+		serverAddress:         config.ServerAddress,
 	}
 	if chainFetcher == nil {
 		return nil, utils.LavaFormatError("can't start chainTracker with nil chainFetcher argument", nil)
 	}
 	chainTracker.endpoint = chainFetcher.FetchEndpoint()
-	err = chainTracker.start(ctx, config.AverageBlockTime)
-	if err != nil {
-		return nil, err
-	}
-	err = chainTracker.serve(ctx, config.ServerAddress)
 
 	return chainTracker, err
 }
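Moving `start`/`serve` out of `NewChainTracker` into `StartAndServe` makes construction side-effect free, which the provider code later in this diff relies on when it builds trackers speculatively. A rough sketch of the split lifecycle, assuming `start` launches the polling loop and `serve` optionally exposes a server (types and bodies here are illustrative):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// tracker keeps its initial config so starting can be deferred: the
// constructor only records state, and StartAndServe is the single explicit
// point where work (polling, listening) actually begins.
type tracker struct {
	averageBlockTime time.Duration
	serverAddress    string
}

func newTracker(avg time.Duration, addr string) *tracker {
	return &tracker{averageBlockTime: avg, serverAddress: addr} // inert until started
}

func (t *tracker) StartAndServe(ctx context.Context) error {
	go func() { // stands in for start(): the block-polling loop
		ticker := time.NewTicker(t.averageBlockTime)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				// fetch the latest block here
			}
		}
	}()
	if t.serverAddress == "" { // stands in for serve(): nothing to expose
		return nil
	}
	fmt.Println("serving on", t.serverAddress)
	return nil
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 300*time.Millisecond)
	defer cancel()
	t := newTracker(100*time.Millisecond, "")
	if err := t.StartAndServe(ctx); err != nil { // errors surface at start time, not construction
		panic(err)
	}
	<-ctx.Done() // let the poller run until the context expires
}
```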
diff --git a/protocol/chaintracker/chain_tracker_test.go b/protocol/chaintracker/chain_tracker_test.go
index c0140af616..1ebcf62a21 100644
--- a/protocol/chaintracker/chain_tracker_test.go
+++ b/protocol/chaintracker/chain_tracker_test.go
@@ -161,6 +161,7 @@ func TestChainTracker(t *testing.T) {
 			chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(tt.fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(tt.mockBlocks)}
 			chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
 			require.NoError(t, err)
+			chainTracker.StartAndServe(context.Background())
 			for _, advancement := range tt.advancements {
 				for i := 0; i < int(advancement); i++ {
@@ -218,6 +219,7 @@ func TestChainTrackerRangeOnly(t *testing.T) {
 			chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(tt.fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(tt.mockBlocks)}
 			chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
 			require.NoError(t, err)
+			chainTracker.StartAndServe(context.Background())
 			for _, advancement := range tt.advancements {
 				for i := 0; i < int(advancement); i++ {
@@ -302,6 +304,7 @@ func TestChainTrackerCallbacks(t *testing.T) {
 	chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks), ForkCallback: forkCallback, NewLatestCallback: newBlockCallback}
 	chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
 	require.NoError(t, err)
+	chainTracker.StartAndServe(context.Background())
 	totalAdvancement := 0
 	t.Run("one long test", func(t *testing.T) {
 		for _, tt := range tests {
@@ -368,6 +371,7 @@ func TestChainTrackerFetchSpreadAcrossPollingTime(t *testing.T) {
 	chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: localTimeForPollingMock, ServerBlockMemory: uint64(mockBlocks)}
 	tracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
 	require.NoError(t, err)
+	tracker.StartAndServe(context.Background())
 	// fool the tracker so it thinks blocks will come every localTimeForPollingMock (ms), and not adjust it's polling timers
 	for i := 0; i < 50; i++ {
 		tracker.AddBlockGap(localTimeForPollingMock, 1)
@@ -491,6 +495,7 @@ func TestChainTrackerPollingTimeUpdate(t *testing.T) {
 			mockChainFetcher.AdvanceBlock()
 			chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: play.localTimeForPollingMock, ServerBlockMemory: uint64(mockBlocks)}
 			tracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
 			tracker.RegisterForBlockTimeUpdates(&mockTimeUpdater)
 			require.NoError(t, err)
+			tracker.StartAndServe(context.Background())
 			// initial delay
@@ -555,6 +560,7 @@ func TestChainTrackerMaintainMemory(t *testing.T) {
 	chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks), ForkCallback: forkCallback}
 	chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
 	require.NoError(t, err)
+	chainTracker.StartAndServe(context.Background())
 	t.Run("one long test", func(t *testing.T) {
 		for _, tt := range tests {
 			utils.LavaFormatInfo(startedTestStr + tt.name)
@@ -607,6 +613,7 @@ func TestFindRequestedBlockHash(t *testing.T) {
 	chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks)}
 	chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
 	require.NoError(t, err)
+	chainTracker.StartAndServe(context.Background())
 	latestBlock, onlyLatestBlockData, _, err := chainTracker.GetLatestBlockData(spectypes.LATEST_BLOCK, spectypes.LATEST_BLOCK, spectypes.NOT_APPLICABLE)
 	require.NoError(t, err)
 	require.Equal(t, currentLatestBlockInMock, latestBlock)
diff --git a/protocol/common/safe_sync_map.go b/protocol/common/safe_sync_map.go
new file mode 100644
index 0000000000..b0e94a421c
--- /dev/null
+++ b/protocol/common/safe_sync_map.go
@@ -0,0 +1,51 @@
+package common
+
+import (
+	"sync"
+
+	"github.com/lavanet/lava/v3/utils"
+)
+
+type SafeSyncMap[K, V any] struct {
+	localMap sync.Map
+}
+
+func (ssm *SafeSyncMap[K, V]) Store(key K, toSet V) {
+	ssm.localMap.Store(key, toSet)
+}
+
+func (ssm *SafeSyncMap[K, V]) Load(key K) (ret V, ok bool, err error) {
+	value, ok := ssm.localMap.Load(key)
+	if !ok {
+		return ret, ok, nil
+	}
+	ret, ok = value.(V)
+	if !ok {
+		return ret, false, utils.LavaFormatError("invalid usage of sync map, could not cast value into the expected type", nil)
+	}
+	return ret, true, nil
+}
+
+// LoadOrStore returns the existing value for the key if present.
+// Otherwise, it stores and returns the given value.
+// The loaded result is true if the value was loaded, false if stored.
+// The function returns the value that was loaded or stored.
+func (ssm *SafeSyncMap[K, V]) LoadOrStore(key K, value V) (ret V, loaded bool, err error) {
+	actual, loaded := ssm.localMap.LoadOrStore(key, value)
+	if loaded {
+		// loaded from map
+		var ok bool
+		ret, ok = actual.(V)
+		if !ok {
+			return ret, false, utils.LavaFormatError("invalid usage of sync map, could not cast value into the expected type", nil)
+		}
+		return ret, true, nil
+	}
+
+	// stored in map
+	return value, false, nil
+}
+
+func (ssm *SafeSyncMap[K, V]) Range(f func(key, value any) bool) {
+	ssm.localMap.Range(f)
+}
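A usage sketch of the new generic wrapper, assuming the import path from this diff. The typed signatures remove the repeated `interface{}` assertions of the deleted per-type maps; the `error` results only fire if a value of the wrong dynamic type somehow reaches the underlying `sync.Map`, which the typed `Store` effectively rules out:

```go
package main

import (
	"fmt"

	"github.com/lavanet/lava/v3/protocol/common"
)

func main() {
	// Keys and values are typed, so call sites lose the interface{}
	// assertions that raw sync.Map (and the deleted wrappers) required.
	m := &common.SafeSyncMap[string, int]{}

	m.Store("LAV1", 42)

	v, found, err := m.Load("LAV1")
	fmt.Println(v, found, err) // 42 true <nil>

	// LoadOrStore keeps the first value stored for a key and reports whether
	// the returned value was loaded (true) or freshly stored (false).
	actual, loaded, err := m.LoadOrStore("LAV1", 7)
	fmt.Println(actual, loaded, err) // 42 true <nil>

	actual, loaded, err = m.LoadOrStore("ETH1", 7)
	fmt.Println(actual, loaded, err) // 7 false <nil>
}
```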
-func (sm *syncMapPolicyUpdaters) LoadOrStore(key string, value *updaters.PolicyUpdater) (ret *updaters.PolicyUpdater, loaded bool) {
-	actual, loaded := sm.localMap.LoadOrStore(key, value)
-	if loaded {
-		// loaded from map
-		ret, loaded = actual.(*updaters.PolicyUpdater)
-		if !loaded {
-			utils.LavaFormatFatal("invalid usage of syncmap, could not cast result into a PolicyUpdater", nil)
-		}
-		return ret, loaded
-	}
-
-	// stored in map
-	return value, false
-}
diff --git a/protocol/rpcconsumer/rpcconsumer.go b/protocol/rpcconsumer/rpcconsumer.go
index 67f4a24461..68083a1834 100644
--- a/protocol/rpcconsumer/rpcconsumer.go
+++ b/protocol/rpcconsumer/rpcconsumer.go
@@ -174,12 +174,15 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt
 	for _, endpoint := range options.rpcEndpoints {
 		chainMutexes[endpoint.ChainID] = &sync.Mutex{} // create a mutex per chain for shared resources
 	}
-	var optimizers sync.Map
-	var consumerConsistencies sync.Map
-	var finalizationConsensuses sync.Map
+
+	optimizers := &common.SafeSyncMap[string, *provideroptimizer.ProviderOptimizer]{}
+	consumerConsistencies := &common.SafeSyncMap[string, *ConsumerConsistency]{}
+	finalizationConsensuses := &common.SafeSyncMap[string, *finalizationconsensus.FinalizationConsensus]{}
+
 	var wg sync.WaitGroup
 	parallelJobs := len(options.rpcEndpoints)
 	wg.Add(parallelJobs)
+
 	errCh := make(chan error)
 
 	consumerStateTracker.RegisterForUpdates(ctx, updaters.NewMetricsUpdater(consumerMetricsManager))
@@ -193,7 +196,7 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt
 	}
 	consumerStateTracker.RegisterForVersionUpdates(ctx, version.Version, &upgrade.ProtocolVersion{})
 	relaysMonitorAggregator := metrics.NewRelaysMonitorAggregator(options.cmdFlags.RelaysHealthIntervalFlag, consumerMetricsManager)
-	policyUpdaters := syncMapPolicyUpdaters{}
+	policyUpdaters := &common.SafeSyncMap[string, *updaters.PolicyUpdater]{}
 	for _, rpcEndpoint := range options.rpcEndpoints {
 		go func(rpcEndpoint *lavasession.RPCEndpoint) error {
 			defer wg.Done()
@@ -206,7 +209,12 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt
 			chainID := rpcEndpoint.ChainID
 			// create policyUpdaters per chain
 			newPolicyUpdater := updaters.NewPolicyUpdater(chainID, consumerStateTracker, consumerAddr.String(), chainParser, *rpcEndpoint)
-			if policyUpdater, ok := policyUpdaters.LoadOrStore(chainID, newPolicyUpdater); ok {
+			policyUpdater, ok, err := policyUpdaters.LoadOrStore(chainID, newPolicyUpdater)
+			if err != nil {
+				errCh <- err
+				return utils.LavaFormatError("failed loading or storing policy updater", err, utils.LogAttr("endpoint", rpcEndpoint))
+			}
+			if ok {
 				err := policyUpdater.AddPolicySetter(chainParser, *rpcEndpoint)
 				if err != nil {
 					errCh <- err
@@ -229,46 +237,33 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt
 				// this is locked so we don't race optimizers creation
 				chainMutexes[chainID].Lock()
 				defer chainMutexes[chainID].Unlock()
-				value, exists := optimizers.Load(chainID)
-				if !exists {
-					// doesn't exist for this chain create a new one
-					baseLatency := common.AverageWorldLatency / 2 // we want performance to be half our timeout or better
-					optimizer = provideroptimizer.NewProviderOptimizer(options.strategy, averageBlockTime, baseLatency, options.maxConcurrentProviders)
-					optimizers.Store(chainID, optimizer)
-				} else {
-					var ok bool
-					optimizer, ok = value.(*provideroptimizer.ProviderOptimizer)
-					if !ok {
-						err = utils.LavaFormatError("failed loading
 optimizer, value is of the wrong type", nil, utils.Attribute{Key: "endpoint", Value: rpcEndpoint.Key()})
-						return err
-					}
+				var loaded bool
+				var err error
+
+				baseLatency := common.AverageWorldLatency / 2 // we want performance to be half our timeout or better
+
+				// Create / Use existing optimizer
+				newOptimizer := provideroptimizer.NewProviderOptimizer(options.strategy, averageBlockTime, baseLatency, options.maxConcurrentProviders)
+				optimizer, _, err = optimizers.LoadOrStore(chainID, newOptimizer)
+				if err != nil {
+					return utils.LavaFormatError("failed loading optimizer", err, utils.LogAttr("endpoint", rpcEndpoint.Key()))
 				}
-				value, exists = consumerConsistencies.Load(chainID)
-				if !exists { // doesn't exist for this chain create a new one
-					consumerConsistency = NewConsumerConsistency(chainID)
-					consumerConsistencies.Store(chainID, consumerConsistency)
-				} else {
-					var ok bool
-					consumerConsistency, ok = value.(*ConsumerConsistency)
-					if !ok {
-						err = utils.LavaFormatError("failed loading consumer consistency, value is of the wrong type", err, utils.Attribute{Key: "endpoint", Value: rpcEndpoint.Key()})
-						return err
-					}
+
+				// Create / Use existing ConsumerConsistency
+				newConsumerConsistency := NewConsumerConsistency(chainID)
+				consumerConsistency, _, err = consumerConsistencies.LoadOrStore(chainID, newConsumerConsistency)
+				if err != nil {
+					return utils.LavaFormatError("failed loading consumer consistency", err, utils.LogAttr("endpoint", rpcEndpoint.Key()))
 				}
-				value, exists = finalizationConsensuses.Load(chainID)
-				if !exists {
-					// doesn't exist for this chain create a new one
-					finalizationConsensus = finalizationconsensus.NewFinalizationConsensus(rpcEndpoint.ChainID)
+
+				// Create / Use existing FinalizationConsensus
+				newFinalizationConsensus := finalizationconsensus.NewFinalizationConsensus(rpcEndpoint.ChainID)
+				finalizationConsensus, loaded, err = finalizationConsensuses.LoadOrStore(chainID, newFinalizationConsensus)
+				if err != nil {
+					return utils.LavaFormatError("failed loading finalization consensus", err, utils.LogAttr("endpoint", rpcEndpoint.Key()))
+				}
+				if !loaded { // when creating a new finalization consensus instance we need to register it for updates
 					consumerStateTracker.RegisterFinalizationConsensusForUpdates(ctx, finalizationConsensus)
-					finalizationConsensuses.Store(chainID, finalizationConsensus)
-				} else {
-					var ok bool
-					finalizationConsensus, ok = value.(*finalizationconsensus.FinalizationConsensus)
-					if !ok {
-						err = utils.LavaFormatError("failed loading finalization consensus, value is of the wrong type", nil, utils.Attribute{Key: "endpoint", Value: rpcEndpoint.Key()})
-						return err
-					}
 				}
 				return nil
 			}
@@ -278,7 +273,7 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt
 				return err
 			}
 
-			if finalizationConsensus == nil || optimizer == nil {
+			if finalizationConsensus == nil || optimizer == nil || consumerConsistency == nil {
 				err = utils.LavaFormatError("failed getting assets, found a nil", nil, utils.Attribute{Key: "endpoint", Value: rpcEndpoint.Key()})
 				errCh <- err
 				return err
@@ -327,9 +322,9 @@ func (rpcc *RPCConsumer) Start(ctx context.Context, options *rpcConsumerStartOpt
 
 	utils.LavaFormatDebug("Starting Policy Updaters for all chains")
 	for chainId := range chainMutexes {
-		policyUpdater, ok := policyUpdaters.Load(chainId)
-		if !ok {
-			utils.LavaFormatError("could not load policy Updater for chain", nil, utils.LogAttr("chain", chainId))
+		policyUpdater, ok, err := policyUpdaters.Load(chainId)
+		if !ok || err != nil {
+			utils.LavaFormatError("could not load policy updater for chain", err, utils.LogAttr("chain", chainId))
 			continue
 		}
 		consumerStateTracker.RegisterForPairingUpdates(ctx, policyUpdater, chainId)
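The rewritten block above replaces three Load-then-branch-then-Store sequences with unconditional construction plus `LoadOrStore`, keying one-time side effects (such as `RegisterFinalizationConsensusForUpdates`) off the `loaded` flag. A self-contained sketch of that shape, with a plain `sync.Map` standing in for `SafeSyncMap` and a placeholder type:

```go
package main

import (
	"fmt"
	"sync"
)

// Placeholder for finalizationconsensus.FinalizationConsensus.
type finalizationConsensus struct{ chainID string }

func main() {
	var consensuses sync.Map // stands in for SafeSyncMap[string, *finalizationConsensus]

	setup := func(chainID string) *finalizationConsensus {
		// Build a candidate unconditionally - cheap, and simpler than a
		// Load/branch/Store dance under the per-chain mutex.
		candidate := &finalizationConsensus{chainID: chainID}
		actual, loaded := consensuses.LoadOrStore(chainID, candidate)
		fc := actual.(*finalizationConsensus)
		if !loaded {
			// Our candidate was stored: it is the shared instance, so run
			// the one-time registration exactly once.
			fmt.Println("registered consensus for", fc.chainID)
		}
		return fc
	}

	a := setup("LAV1")
	b := setup("LAV1") // a second endpoint on the same chain reuses the instance
	fmt.Println(a == b, "- no duplicate registration")
}
```

The cost of this shape is that a candidate is constructed even when an instance already exists; as with the gRPC descriptor cache earlier in the diff, that only pays off when the constructors are cheap.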
utils.LavaFormatError("could not load policy Updater for chain", err, utils.LogAttr("chain", chainId)) continue } consumerStateTracker.RegisterForPairingUpdates(ctx, policyUpdater, chainId) diff --git a/protocol/rpcconsumer/testing.go b/protocol/rpcconsumer/testing.go index fffc60d214..6c3683186e 100644 --- a/protocol/rpcconsumer/testing.go +++ b/protocol/rpcconsumer/testing.go @@ -79,6 +79,7 @@ func startTesting(ctx context.Context, clientCtx client.Context, rpcEndpoints [] if err != nil { return utils.LavaFormatError("panic severity critical error, aborting support for chain api due to node access, continuing with other endpoints", err, utils.Attribute{Key: "chainTrackerConfig", Value: chainTrackerConfig}, utils.Attribute{Key: "endpoint", Value: rpcProviderEndpoint}) } + chainTracker.StartAndServe(ctx) _ = chainTracker // let the chain tracker work and make queries return nil }(rpcProviderEndpoint) diff --git a/protocol/rpcprovider/chain_tackers.go b/protocol/rpcprovider/chain_tackers.go deleted file mode 100644 index 95a43a5ea1..0000000000 --- a/protocol/rpcprovider/chain_tackers.go +++ /dev/null @@ -1,38 +0,0 @@ -package rpcprovider - -import ( - "sync" - - "github.com/lavanet/lava/v3/protocol/chaintracker" - "github.com/lavanet/lava/v3/utils" -) - -type ChainTrackers struct { - stateTrackersPerChain sync.Map -} - -func (ct *ChainTrackers) GetTrackerPerChain(specID string) (chainTracker *chaintracker.ChainTracker, found bool) { - chainTrackerInf, found := ct.stateTrackersPerChain.Load(specID) - if !found { - return nil, found - } - var ok bool - chainTracker, ok = chainTrackerInf.(*chaintracker.ChainTracker) - if !ok { - utils.LavaFormatFatal("invalid usage of syncmap, could not cast result into a chaintracker", nil) - } - return chainTracker, true -} - -func (ct *ChainTrackers) SetTrackerForChain(specId string, chainTracker *chaintracker.ChainTracker) { - ct.stateTrackersPerChain.Store(specId, chainTracker) -} - -func (ct *ChainTrackers) GetLatestBlockNumForSpec(specID string) int64 { - chainTracker, found := ct.GetTrackerPerChain(specID) - if !found { - return 0 - } - latestBlock, _ := chainTracker.GetLatestBlockNum() - return latestBlock -} diff --git a/protocol/rpcprovider/rpcprovider.go b/protocol/rpcprovider/rpcprovider.go index 554d7cfbad..9b860b433b 100644 --- a/protocol/rpcprovider/rpcprovider.go +++ b/protocol/rpcprovider/rpcprovider.go @@ -133,7 +133,7 @@ type RPCProvider struct { parallelConnections uint cache *performance.Cache shardID uint // shardID is a flag that allows setting up multiple provider databases of the same chain - chainTrackers *ChainTrackers + chainTrackers *common.SafeSyncMap[string, *chaintracker.ChainTracker] relaysMonitorAggregator *metrics.RelaysMonitorAggregator relaysHealthCheckEnabled bool relaysHealthCheckInterval time.Duration @@ -152,7 +152,7 @@ func (rpcp *RPCProvider) Start(options *rpcProviderStartOptions) (err error) { cancel() }() rpcp.providerUniqueId = strconv.FormatUint(utils.GenerateUniqueIdentifier(), 10) - rpcp.chainTrackers = &ChainTrackers{} + rpcp.chainTrackers = &common.SafeSyncMap[string, *chaintracker.ChainTracker]{} rpcp.parallelConnections = options.parallelConnections rpcp.cache = options.cache rpcp.providerMetricsManager = metrics.NewProviderMetricsManager(options.metricsListenAddress) // start up prometheus metrics @@ -185,7 +185,7 @@ func (rpcp *RPCProvider) Start(options *rpcProviderStartOptions) (err error) { // single reward server if !options.staticProvider { rewardDB := 
 rewardserver.NewRewardDBWithTTL(options.rewardTTL)
-		rpcp.rewardServer = rewardserver.NewRewardServer(providerStateTracker, rpcp.providerMetricsManager, rewardDB, options.rewardStoragePath, options.rewardsSnapshotThreshold, options.rewardsSnapshotTimeoutSec, rpcp.chainTrackers)
+		rpcp.rewardServer = rewardserver.NewRewardServer(providerStateTracker, rpcp.providerMetricsManager, rewardDB, options.rewardStoragePath, options.rewardsSnapshotThreshold, options.rewardsSnapshotTimeoutSec, rpcp)
 		rpcp.providerStateTracker.RegisterForEpochUpdates(ctx, rpcp.rewardServer)
 		rpcp.providerStateTracker.RegisterPaymentUpdatableForPayments(ctx, rpcp.rewardServer)
 	}
@@ -409,42 +409,45 @@ func (rpcp *RPCProvider) SetupEndpoint(ctx context.Context, rpcProviderEndpoint
 	chainCommonSetup := func() error {
 		rpcp.chainMutexes[chainID].Lock()
 		defer rpcp.chainMutexes[chainID].Unlock()
-		var found bool
-		chainTracker, found = rpcp.chainTrackers.GetTrackerPerChain(chainID)
-		if !found {
-			consistencyErrorCallback := func(oldBlock, newBlock int64) {
-				utils.LavaFormatError("Consistency issue detected", nil,
-					utils.Attribute{Key: "oldBlock", Value: oldBlock},
-					utils.Attribute{Key: "newBlock", Value: newBlock},
-					utils.Attribute{Key: "Chain", Value: rpcProviderEndpoint.ChainID},
-					utils.Attribute{Key: "apiInterface", Value: apiInterface},
-				)
-			}
-			blocksToSaveChainTracker := uint64(blocksToFinalization + blocksInFinalizationData)
-			chainTrackerConfig := chaintracker.ChainTrackerConfig{
-				BlocksToSave:        blocksToSaveChainTracker,
-				AverageBlockTime:    averageBlockTime,
-				ServerBlockMemory:   ChainTrackerDefaultMemory + blocksToSaveChainTracker,
-				NewLatestCallback:   recordMetricsOnNewBlock,
-				ConsistencyCallback: consistencyErrorCallback,
-				Pmetrics:            rpcp.providerMetricsManager,
-			}
+		var loaded bool
+		consistencyErrorCallback := func(oldBlock, newBlock int64) {
+			utils.LavaFormatError("Consistency issue detected", nil,
+				utils.Attribute{Key: "oldBlock", Value: oldBlock},
+				utils.Attribute{Key: "newBlock", Value: newBlock},
+				utils.Attribute{Key: "Chain", Value: rpcProviderEndpoint.ChainID},
+				utils.Attribute{Key: "apiInterface", Value: apiInterface},
+			)
+		}
+		blocksToSaveChainTracker := uint64(blocksToFinalization + blocksInFinalizationData)
+		chainTrackerConfig := chaintracker.ChainTrackerConfig{
+			BlocksToSave:        blocksToSaveChainTracker,
+			AverageBlockTime:    averageBlockTime,
+			ServerBlockMemory:   ChainTrackerDefaultMemory + blocksToSaveChainTracker,
+			NewLatestCallback:   recordMetricsOnNewBlock,
+			ConsistencyCallback: consistencyErrorCallback,
+			Pmetrics:            rpcp.providerMetricsManager,
+		}
 
-			chainTracker, err = chaintracker.NewChainTracker(ctx, chainFetcher, chainTrackerConfig)
-			if err != nil {
-				return utils.LavaFormatError("panic severity critical error, aborting support for chain api due to node access, continuing with other endpoints", err, utils.Attribute{Key: "chainTrackerConfig", Value: chainTrackerConfig}, utils.Attribute{Key: "endpoint", Value: rpcProviderEndpoint})
-			}
+		chainTracker, err = chaintracker.NewChainTracker(ctx, chainFetcher, chainTrackerConfig)
+		if err != nil {
+			return utils.LavaFormatError("panic severity critical error, aborting support for chain api due to node access, continuing with other endpoints", err, utils.Attribute{Key: "chainTrackerConfig", Value: chainTrackerConfig}, utils.Attribute{Key: "endpoint", Value: rpcProviderEndpoint})
+		}
+
+		chainTrackerLoaded, loaded, err := rpcp.chainTrackers.LoadOrStore(chainID, chainTracker)
+		if err != nil {
+			utils.LavaFormatFatal("failed to load or store chain tracker", err, utils.LogAttr("chainID", chainID))
+		}
+		if !loaded { // this is the first time we are setting up the chain tracker, we need to register for spec verifications
+			chainTracker.StartAndServe(ctx)
 			utils.LavaFormatDebug("Registering for spec verifications for endpoint", utils.LogAttr("rpcEndpoint", rpcEndpoint))
 			// we register for spec verifications only once, and this triggers all chainFetchers of that specId when it triggers
 			err = rpcp.providerStateTracker.RegisterForSpecVerifications(ctx, specValidator, rpcEndpoint.ChainID)
 			if err != nil {
 				return utils.LavaFormatError("failed to RegisterForSpecUpdates, panic severity critical error, aborting support for chain api due to invalid chain parser, continuing with others", err, utils.Attribute{Key: "endpoint", Value: rpcProviderEndpoint.String()})
 			}
-
-			// Any validation needs to be before we store chain tracker for given chain id
-			rpcp.chainTrackers.SetTrackerForChain(rpcProviderEndpoint.ChainID, chainTracker)
-		} else {
+		} else { // loaded an existing chain tracker, use it instead
+			chainTracker = chainTrackerLoaded
 			utils.LavaFormatDebug("reusing chain tracker", utils.Attribute{Key: "chain", Value: rpcProviderEndpoint.ChainID})
 		}
 		return nil
@@ -516,6 +519,19 @@ func (rpcp *RPCProvider) SetupEndpoint(ctx context.Context, rpcProviderEndpoint
 	return nil
 }
 
+func (rpcp *RPCProvider) GetLatestBlockNumForSpec(specID string) int64 {
+	chainTracker, found, err := rpcp.chainTrackers.Load(specID)
+	if err != nil {
+		utils.LavaFormatFatal("failed to load chain tracker", err, utils.LogAttr("specID", specID))
+	}
+	if !found {
+		return 0
+	}
+
+	block, _ := chainTracker.GetLatestBlockNum()
+	return block
+}
+
 func ParseEndpointsCustomName(viper_endpoints *viper.Viper, endpointsConfigName string, geolocation uint64) (endpoints []*lavasession.RPCProviderEndpoint, err error) {
 	err = viper_endpoints.UnmarshalKey(endpointsConfigName, &endpoints)
 	if err != nil {
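The provider setup above leans on the construct/start split: a tracker is built on every call, but only the instance that wins `LoadOrStore` is started via `StartAndServe`, so a losing candidate is discarded before it ever polls or binds a port. A sketch of that flow with stand-in types:

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

// Stand-in for chaintracker.ChainTracker; construction is inert, so an
// instance that loses the LoadOrStore race below is simply garbage collected.
type chainTracker struct{ chainID string }

func (ct *chainTracker) StartAndServe(ctx context.Context) error {
	fmt.Println("started tracker for", ct.chainID)
	return nil
}

func main() {
	ctx := context.Background()
	var trackers sync.Map // stands in for SafeSyncMap[string, *chainTracker]

	setup := func(chainID string) *chainTracker {
		candidate := &chainTracker{chainID: chainID}
		actual, loaded := trackers.LoadOrStore(chainID, candidate)
		tracker := actual.(*chainTracker)
		if !loaded {
			// first setup for this chain: start only the stored instance
			if err := tracker.StartAndServe(ctx); err != nil {
				panic(err)
			}
		}
		return tracker
	}

	setup("LAV1")
	setup("LAV1") // reuses the running tracker; the new candidate is dropped unstarted
}
```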
chain tracker", err, utils.LogAttr("chainID", chainID)) + } + if !loaded { // this is the first time we are setting up the chain tracker, we need to register for spec verifications + chainTracker.StartAndServe(ctx) utils.LavaFormatDebug("Registering for spec verifications for endpoint", utils.LogAttr("rpcEndpoint", rpcEndpoint)) // we register for spec verifications only once, and this triggers all chainFetchers of that specId when it triggers err = rpcp.providerStateTracker.RegisterForSpecVerifications(ctx, specValidator, rpcEndpoint.ChainID) if err != nil { return utils.LavaFormatError("failed to RegisterForSpecUpdates, panic severity critical error, aborting support for chain api due to invalid chain parser, continuing with others", err, utils.Attribute{Key: "endpoint", Value: rpcProviderEndpoint.String()}) } - - // Any validation needs to be before we store chain tracker for given chain id - rpcp.chainTrackers.SetTrackerForChain(rpcProviderEndpoint.ChainID, chainTracker) - } else { + } else { // loaded an existing chain tracker. use the same one instead + chainTracker = chainTrackerLoaded utils.LavaFormatDebug("reusing chain tracker", utils.Attribute{Key: "chain", Value: rpcProviderEndpoint.ChainID}) } return nil @@ -516,6 +519,19 @@ func (rpcp *RPCProvider) SetupEndpoint(ctx context.Context, rpcProviderEndpoint return nil } +func (rpcp *RPCProvider) GetLatestBlockNumForSpec(specID string) int64 { + chainTracker, found, err := rpcp.chainTrackers.Load(specID) + if err != nil { + utils.LavaFormatFatal("failed to load chain tracker", err, utils.LogAttr("specID", specID)) + } + if !found { + return 0 + } + + block, _ := chainTracker.GetLatestBlockNum() + return block +} + func ParseEndpointsCustomName(viper_endpoints *viper.Viper, endpointsConfigName string, geolocation uint64) (endpoints []*lavasession.RPCProviderEndpoint, err error) { err = viper_endpoints.UnmarshalKey(endpointsConfigName, &endpoints) if err != nil { diff --git a/protocol/statetracker/events.go b/protocol/statetracker/events.go index ff4b8dcccd..fa1383e9f5 100644 --- a/protocol/statetracker/events.go +++ b/protocol/statetracker/events.go @@ -24,6 +24,7 @@ import ( "github.com/lavanet/lava/v3/app" "github.com/lavanet/lava/v3/protocol/chainlib" "github.com/lavanet/lava/v3/protocol/chaintracker" + "github.com/lavanet/lava/v3/protocol/common" "github.com/lavanet/lava/v3/protocol/rpcprovider/rewardserver" updaters "github.com/lavanet/lava/v3/protocol/statetracker/updaters" "github.com/lavanet/lava/v3/utils" @@ -122,6 +123,7 @@ func eventsLookup(ctx context.Context, clientCtx client.Context, blocks, fromBlo if err != nil { return utils.LavaFormatError("failed setting up chain tracker", err) } + chainTracker.StartAndServe(ctx) _ = chainTracker select { case <-ctx.Done(): @@ -666,7 +668,7 @@ func countTransactionsPerDay(ctx context.Context, clientCtx client.Context, bloc // j are blocks in that day // starting from current day and going backwards var wg sync.WaitGroup - totalTxPerDay := sync.Map{} + totalTxPerDay := &common.SafeSyncMap[int64, int]{} // Process each day from the earliest to the latest for i := int64(1); i <= numberOfDays; i++ { @@ -703,14 +705,13 @@ func countTransactionsPerDay(ctx context.Context, clientCtx client.Context, bloc transactionResults := blockResults.TxsResults utils.LavaFormatInfo("Number of tx for block", utils.LogAttr("_routine", end-k), utils.LogAttr("block_number", k), utils.LogAttr("number_of_tx", len(transactionResults))) // Update totalTxPerDay safely - actual, _ := 
 totalTxPerDay.LoadOrStore(i, len(transactionResults))
-				if actual != nil {
-					val, ok := actual.(int)
-					if !ok {
-						utils.LavaFormatError("Failed converting int", nil)
-						return
-					}
-					totalTxPerDay.Store(i, val+len(transactionResults))
+				actual, loaded, err := totalTxPerDay.LoadOrStore(i, len(transactionResults))
+				if err != nil {
+					utils.LavaFormatError("failed to load or store", err)
+					return
+				}
+				if loaded {
+					totalTxPerDay.Store(i, actual+len(transactionResults))
 				}
 			}(k)
 		}
diff --git a/protocol/statetracker/state_tracker.go b/protocol/statetracker/state_tracker.go
index 96c833781f..5ff1312702 100644
--- a/protocol/statetracker/state_tracker.go
+++ b/protocol/statetracker/state_tracker.go
@@ -130,6 +130,7 @@ func NewStateTracker(ctx context.Context, txFactory tx.Factory, clientCtx client
 	}
 	cst.AverageBlockTime = chainTrackerConfig.AverageBlockTime
 	cst.chainTracker, err = chaintracker.NewChainTracker(ctx, chainFetcher, chainTrackerConfig)
+	cst.chainTracker.StartAndServe(ctx)
 	cst.chainTracker.RegisterForBlockTimeUpdates(cst) // registering for block time updates.
 	return cst, err
 }