fix: PRT: Fix all sync.Map uses (#1688)
* Create and use SafeSyncMap

* Change from Load and then Store to LoadOrStore

* fixed elad

* fix elad 2

* Fix chaintracker test

* Add more missing StartAndServe

* Small fix to rpcconsumer

* CR Fix: Add description to why we do Load and then Store

* CR Fix: Log error instead of Fatal when erroring on Load

* tidy code

* push

---------

Co-authored-by: Ran Mishael <ran@lavanet.xyz>
Co-authored-by: Omer <100387053+omerlavanet@users.noreply.github.com>
3 people authored Sep 15, 2024
1 parent 28a53a3 commit aae38b9
Showing 12 changed files with 183 additions and 201 deletions.
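
The commit bullets above boil down to two ideas: wrap `sync.Map` in a typed `SafeSyncMap`, and replace Load-then-Store with `LoadOrStore` where a race matters. The sketch below is a hedged illustration rather than code from this commit (the keys and the `buildValue` helper are hypothetical); it shows the race a bare Load-then-Store allows and how `LoadOrStore` resolves it atomically.

```go
package main

import (
	"fmt"
	"sync"
)

// buildValue is a hypothetical stand-in for an expensive but idempotent
// constructor; in this commit the analogue is building a gRPC method descriptor.
func buildValue(key string) string {
	return "value-for-" + key
}

func main() {
	var m sync.Map

	// Load-then-Store: two goroutines can both miss on Load and both Store,
	// with the later write overwriting the earlier one. That is tolerable only
	// when rebuilding the value is harmless, as the new comment in grpc.go notes.
	if _, ok := m.Load("alpha"); !ok {
		m.Store("alpha", buildValue("alpha"))
	}

	// LoadOrStore: one atomic step that either returns the value a concurrent
	// writer already stored or stores ours, so every caller sees the same value.
	actual, loaded := m.LoadOrStore("beta", buildValue("beta"))
	fmt.Println(actual, loaded) // value-for-beta false
}
```

Because `sync.Map` is untyped, every read also needs a type assertion; the new `SafeSyncMap` (shown in full further down) hides that assertion behind generics.
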
35 changes: 9 additions & 26 deletions protocol/chainlib/grpc.go
@@ -9,7 +9,6 @@ import (
"net/http"
"strconv"
"strings"
"sync"
"time"

"github.com/goccy/go-json"
@@ -50,25 +49,6 @@ type GrpcNodeErrorResponse struct {
ErrorCode uint32 `json:"error_code"`
}

- type grpcDescriptorCache struct {
- cachedDescriptors sync.Map // method name is the key, method descriptor is the value
- }
-
- func (gdc *grpcDescriptorCache) getDescriptor(methodName string) *desc.MethodDescriptor {
- if descriptor, ok := gdc.cachedDescriptors.Load(methodName); ok {
- converted, success := descriptor.(*desc.MethodDescriptor) // convert to a descriptor
- if success {
- return converted
- }
- utils.LavaFormatError("Failed Converting method descriptor", nil, utils.Attribute{Key: "Method", Value: methodName})
- }
- return nil
- }
-
- func (gdc *grpcDescriptorCache) setDescriptor(methodName string, descriptor *desc.MethodDescriptor) {
- gdc.cachedDescriptors.Store(methodName, descriptor)
- }

type GrpcChainParser struct {
BaseChainParser

@@ -388,7 +368,7 @@ func (apil *GrpcChainListener) GetListeningAddress() string {
type GrpcChainProxy struct {
BaseChainProxy
conn grpcConnectorInterface
- descriptorsCache *grpcDescriptorCache
+ descriptorsCache *common.SafeSyncMap[string, *desc.MethodDescriptor]
}
type grpcConnectorInterface interface {
Close()
@@ -413,7 +393,7 @@ func NewGrpcChainProxy(ctx context.Context, nConns uint, rpcProviderEndpoint lav
func newGrpcChainProxy(ctx context.Context, averageBlockTime time.Duration, parser ChainParser, conn grpcConnectorInterface, rpcProviderEndpoint lavasession.RPCProviderEndpoint) (ChainProxy, error) {
cp := &GrpcChainProxy{
BaseChainProxy: BaseChainProxy{averageBlockTime: averageBlockTime, ErrorHandler: &GRPCErrorHandler{}, ChainID: rpcProviderEndpoint.ChainID, HashedNodeUrl: chainproxy.HashURL(rpcProviderEndpoint.NodeUrls[0].Url)},
- descriptorsCache: &grpcDescriptorCache{},
+ descriptorsCache: &common.SafeSyncMap[string, *desc.MethodDescriptor]{},
}
cp.conn = conn
if cp.conn == nil {
@@ -471,9 +451,12 @@ func (cp *GrpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{},
descriptorSource := rpcInterfaceMessages.DescriptorSourceFromServer(cl)
svc, methodName := rpcInterfaceMessages.ParseSymbol(nodeMessage.Path)

- // check if we have method descriptor already cached.
- methodDescriptor := cp.descriptorsCache.getDescriptor(methodName)
- if methodDescriptor == nil { // method descriptor not cached yet, need to fetch it and add to cache
+ // Check if we have method descriptor already cached.
+ // The reason we do Load and then Store here, instead of LoadOrStore:
+ // On the worst case scenario, where 2 threads are accessing the map at the same time, the same descriptor will be stored twice.
+ // It is better than the alternative, which is always creating the descriptor, since the outcome is the same.
+ methodDescriptor, found, _ := cp.descriptorsCache.Load(methodName)
+ if !found { // method descriptor not cached yet, need to fetch it and add to cache
var descriptor desc.Descriptor
if descriptor, err = descriptorSource.FindSymbol(svc); err != nil {
return nil, "", nil, utils.LavaFormatError("descriptorSource.FindSymbol", err, utils.Attribute{Key: "GUID", Value: ctx})
@@ -488,7 +471,7 @@ func (cp *GrpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{},
}

// add the descriptor to the chainProxy cache
- cp.descriptorsCache.setDescriptor(methodName, methodDescriptor)
+ cp.descriptorsCache.Store(methodName, methodDescriptor)
}

msgFactory := dynamic.NewMessageFactoryWithDefaults()
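
The comment added in this hunk explains why the call site keeps Load followed by Store instead of `LoadOrStore`: in the worst case two goroutines both build and store the same descriptor, which is harmless because the descriptor is idempotent to build. As a hedged alternative sketch (not part of this commit), the expensive lookup can stay on the miss path while `LoadOrStore` deduplicates the final write; `loadOrBuild` is a hypothetical helper and the import path is inferred from the file layout shown in this diff.

```go
package main

import (
	"fmt"

	// Import path inferred from the file layout shown in this diff.
	"github.com/lavanet/lava/v3/protocol/common"
)

// loadOrBuild is a hypothetical helper: build runs only on a cache miss (like
// the FindSymbol/FindMethodByName sequence above), and LoadOrStore guarantees
// that concurrent callers all end up holding the same stored value.
func loadOrBuild[V any](cache *common.SafeSyncMap[string, V], key string, build func(string) (V, error)) (V, error) {
	if cached, found, err := cache.Load(key); err == nil && found {
		return cached, nil
	}
	built, err := build(key)
	if err != nil {
		var zero V
		return zero, err
	}
	actual, _, err := cache.LoadOrStore(key, built)
	return actual, err
}

func main() {
	cache := &common.SafeSyncMap[string, string]{}
	descriptor, err := loadOrBuild(cache, "pkg.Service/Method", func(key string) (string, error) {
		return "descriptor-for-" + key, nil
	})
	fmt.Println(descriptor, err)
}
```
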
21 changes: 16 additions & 5 deletions protocol/chaintracker/chain_tracker.go
@@ -68,6 +68,10 @@ type ChainTracker struct {
blockEventsGap []time.Duration
blockTimeUpdatables map[blockTimeUpdatable]struct{}
pmetrics *metrics.ProviderMetricsManager

+ // initial config
+ averageBlockTime time.Duration
+ serverAddress string
}

// this function returns block hashes of the blocks: [from block - to block] inclusive. an additional specific block hash can be provided. order is sorted ascending
@@ -570,6 +574,16 @@ func (ct *ChainTracker) serve(ctx context.Context, listenAddr string) error {
return nil
}

+ func (ct *ChainTracker) StartAndServe(ctx context.Context) error {
+ err := ct.start(ctx, ct.averageBlockTime)
+ if err != nil {
+ return err
+ }
+
+ err = ct.serve(ctx, ct.serverAddress)
+ return err
+ }

func NewChainTracker(ctx context.Context, chainFetcher ChainFetcher, config ChainTrackerConfig) (chainTracker *ChainTracker, err error) {
if !rand.Initialized() {
utils.LavaFormatFatal("can't start chainTracker with nil rand source", nil)
@@ -598,16 +612,13 @@ func NewChainTracker(ctx context.Context, chainFetcher ChainFetcher, config Chai
startupTime: time.Now(),
pmetrics: config.Pmetrics,
pollingTimeMultiplier: time.Duration(pollingTime),
+ averageBlockTime: config.AverageBlockTime,
+ serverAddress: config.ServerAddress,
}
if chainFetcher == nil {
return nil, utils.LavaFormatError("can't start chainTracker with nil chainFetcher argument", nil)
}
chainTracker.endpoint = chainFetcher.FetchEndpoint()
- err = chainTracker.start(ctx, config.AverageBlockTime)
- if err != nil {
- return nil, err
- }

- err = chainTracker.serve(ctx, config.ServerAddress)
return chainTracker, err
}
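
With this change `NewChainTracker` only constructs the tracker, recording the average block time and server address from the config; polling and serving begin when the caller invokes the new `StartAndServe`. Below is a minimal lifecycle sketch under stated assumptions: `fetcher` and `cfg` are a valid `ChainFetcher` and `ChainTrackerConfig`, `runTracker` is a hypothetical helper, and the import path follows the file paths shown in this diff.

```go
package main

import (
	"context"
	"log"

	// Import path inferred from the file paths shown in this diff.
	"github.com/lavanet/lava/v3/protocol/chaintracker"
)

// runTracker is a hypothetical helper showing the new two-step lifecycle:
// construct first, then call StartAndServe to begin polling and serving.
func runTracker(ctx context.Context, fetcher chaintracker.ChainFetcher, cfg chaintracker.ChainTrackerConfig) (*chaintracker.ChainTracker, error) {
	tracker, err := chaintracker.NewChainTracker(ctx, fetcher, cfg)
	if err != nil {
		return nil, err
	}
	// NewChainTracker no longer starts or serves; StartAndServe does both.
	if err := tracker.StartAndServe(ctx); err != nil {
		return nil, err
	}
	return tracker, nil
}

func main() {
	// Building a concrete ChainFetcher is chain-specific and omitted here.
	log.Println("see runTracker for the construct-then-start sequence")
}
```
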
7 changes: 7 additions & 0 deletions protocol/chaintracker/chain_tracker_test.go
@@ -161,6 +161,7 @@ func TestChainTracker(t *testing.T) {

chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(tt.fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(tt.mockBlocks)}
chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
+ chainTracker.StartAndServe(context.Background())
require.NoError(t, err)
for _, advancement := range tt.advancements {
for i := 0; i < int(advancement); i++ {
@@ -218,6 +219,7 @@ func TestChainTrackerRangeOnly(t *testing.T) {

chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(tt.fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(tt.mockBlocks)}
chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
+ chainTracker.StartAndServe(context.Background())
require.NoError(t, err)
for _, advancement := range tt.advancements {
for i := 0; i < int(advancement); i++ {
@@ -302,6 +304,7 @@ func TestChainTrackerCallbacks(t *testing.T) {
chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks), ForkCallback: forkCallback, NewLatestCallback: newBlockCallback}
chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
require.NoError(t, err)
+ chainTracker.StartAndServe(context.Background())
totalAdvancement := 0
t.Run("one long test", func(t *testing.T) {
for _, tt := range tests {
@@ -368,6 +371,7 @@ func TestChainTrackerFetchSpreadAcrossPollingTime(t *testing.T) {
chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: localTimeForPollingMock, ServerBlockMemory: uint64(mockBlocks)}
tracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
require.NoError(t, err)
+ tracker.StartAndServe(context.Background())
// fool the tracker so it thinks blocks will come every localTimeForPollingMock (ms), and not adjust it's polling timers
for i := 0; i < 50; i++ {
tracker.AddBlockGap(localTimeForPollingMock, 1)
@@ -491,6 +495,7 @@ func TestChainTrackerPollingTimeUpdate(t *testing.T) {
mockChainFetcher.AdvanceBlock()
chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: play.localTimeForPollingMock, ServerBlockMemory: uint64(mockBlocks)}
tracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
+ tracker.StartAndServe(context.Background())
tracker.RegisterForBlockTimeUpdates(&mockTimeUpdater)
require.NoError(t, err)
// initial delay
@@ -555,6 +560,7 @@ func TestChainTrackerMaintainMemory(t *testing.T) {
chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks), ForkCallback: forkCallback}
chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
require.NoError(t, err)
+ chainTracker.StartAndServe(context.Background())
t.Run("one long test", func(t *testing.T) {
for _, tt := range tests {
utils.LavaFormatInfo(startedTestStr + tt.name)
@@ -607,6 +613,7 @@ func TestFindRequestedBlockHash(t *testing.T) {
chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks)}
chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
require.NoError(t, err)
+ chainTracker.StartAndServe(context.Background())
latestBlock, onlyLatestBlockData, _, err := chainTracker.GetLatestBlockData(spectypes.LATEST_BLOCK, spectypes.LATEST_BLOCK, spectypes.NOT_APPLICABLE)
require.NoError(t, err)
require.Equal(t, currentLatestBlockInMock, latestBlock)
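
Each test above now adds an explicit `StartAndServe` call after constructing the tracker. A condensed sketch of that setup, assuming `mockChainFetcher` and `chainTrackerConfig` are built as in the surrounding tests; the error returned by `StartAndServe` can also be asserted, although the tests above discard it.

```go
// Condensed setup sketch; mockChainFetcher and chainTrackerConfig are assumed
// to be constructed as in the tests above.
chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
require.NoError(t, err) // assert construction before using the tracker
require.NoError(t, chainTracker.StartAndServe(context.Background()))
```
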
51 changes: 51 additions & 0 deletions protocol/common/safe_sync_map.go
@@ -0,0 +1,51 @@
package common

import (
"sync"

"github.com/lavanet/lava/v3/utils"
)

type SafeSyncMap[K, V any] struct {
localMap sync.Map
}

func (ssm *SafeSyncMap[K, V]) Store(key K, toSet V) {
ssm.localMap.Store(key, toSet)
}

func (ssm *SafeSyncMap[K, V]) Load(key K) (ret V, ok bool, err error) {
value, ok := ssm.localMap.Load(key)
if !ok {
return ret, ok, nil
}
ret, ok = value.(V)
if !ok {
return ret, false, utils.LavaFormatError("invalid usage of syncmap, could not cast result into a PolicyUpdater", nil)
}
return ret, true, nil
}

// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
// The function returns the value that was loaded or stored.
func (ssm *SafeSyncMap[K, V]) LoadOrStore(key K, value V) (ret V, loaded bool, err error) {
actual, loaded := ssm.localMap.LoadOrStore(key, value)
if loaded {
// loaded from map
var ok bool
ret, ok = actual.(V)
if !ok {
return ret, false, utils.LavaFormatError("invalid usage of sync map, could not cast result into a PolicyUpdater", nil)
}
return ret, true, nil
}

// stored in map
return value, false, nil
}

func (ssm *SafeSyncMap[K, V]) Range(f func(key, value any) bool) {
ssm.localMap.Range(f)
}
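
A hedged usage sketch for the new `SafeSyncMap` (the key/value types and values are illustrative, and the import path is inferred from the file layout above): the generic wrapper keeps `sync.Map`'s concurrency behaviour while returning typed values, surfacing a failed type assertion as an error instead of a silent mismatch.

```go
package main

import (
	"fmt"

	// Import path inferred from the file layout shown above.
	"github.com/lavanet/lava/v3/protocol/common"
)

func main() {
	cache := &common.SafeSyncMap[string, int]{}

	cache.Store("answer", 42)

	// Load returns the typed value, whether the key was present, and an error
	// that is only non-nil if a stored value fails the type assertion.
	value, found, err := cache.Load("answer")
	fmt.Println(value, found, err) // 42 true <nil>

	// LoadOrStore keeps the existing value when one is already present.
	actual, loaded, err := cache.LoadOrStore("answer", 7)
	fmt.Println(actual, loaded, err) // 42 true <nil>

	// Range still exposes untyped keys and values, mirroring sync.Map.Range.
	cache.Range(func(key, value any) bool {
		fmt.Println(key, value)
		return true
	})
}
```
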
1 change: 1 addition & 0 deletions protocol/integration/protocol_test.go
@@ -293,6 +293,7 @@ func createRpcProvider(t *testing.T, ctx context.Context, consumerAddress string
mockChainFetcher := NewMockChainFetcher(1000, int64(blocksToSaveChainTracker), nil)
chainTracker, err := chaintracker.NewChainTracker(ctx, mockChainFetcher, chainTrackerConfig)
require.NoError(t, err)
+ chainTracker.StartAndServe(ctx)
reliabilityManager := reliabilitymanager.NewReliabilityManager(chainTracker, &mockProviderStateTracker, account.Addr.String(), chainRouter, chainParser)
mockReliabilityManager := NewMockReliabilityManager(reliabilityManager)
rpcProviderServer.ServeRPCRequests(ctx, rpcProviderEndpoint, chainParser, rws, providerSessionManager, mockReliabilityManager, account.SK, nil, chainRouter, &mockProviderStateTracker, account.Addr, lavaChainID, rpcprovider.DEFAULT_ALLOWED_MISSING_CU, nil, nil, nil, false)
47 changes: 0 additions & 47 deletions protocol/rpcconsumer/policies_map.go

This file was deleted.
