Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: PRT: Fix all sync.Map uses #1688

Merged
merged 13 commits into from
Sep 15, 2024
35 changes: 9 additions & 26 deletions protocol/chainlib/grpc.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"net/http"
"strconv"
"strings"
"sync"
"time"

"github.com/goccy/go-json"
Expand Down Expand Up @@ -50,25 +49,6 @@ type GrpcNodeErrorResponse struct {
ErrorCode uint32 `json:"error_code"`
}

type grpcDescriptorCache struct {
cachedDescriptors sync.Map // method name is the key, method descriptor is the value
}

func (gdc *grpcDescriptorCache) getDescriptor(methodName string) *desc.MethodDescriptor {
if descriptor, ok := gdc.cachedDescriptors.Load(methodName); ok {
converted, success := descriptor.(*desc.MethodDescriptor) // convert to a descriptor
if success {
return converted
}
utils.LavaFormatError("Failed Converting method descriptor", nil, utils.Attribute{Key: "Method", Value: methodName})
}
return nil
}

func (gdc *grpcDescriptorCache) setDescriptor(methodName string, descriptor *desc.MethodDescriptor) {
gdc.cachedDescriptors.Store(methodName, descriptor)
}

type GrpcChainParser struct {
BaseChainParser

Expand Down Expand Up @@ -388,7 +368,7 @@ func (apil *GrpcChainListener) GetListeningAddress() string {
type GrpcChainProxy struct {
BaseChainProxy
conn grpcConnectorInterface
descriptorsCache *grpcDescriptorCache
descriptorsCache *common.SafeSyncMap[string, *desc.MethodDescriptor]
}
type grpcConnectorInterface interface {
Close()
Expand All @@ -413,7 +393,7 @@ func NewGrpcChainProxy(ctx context.Context, nConns uint, rpcProviderEndpoint lav
func newGrpcChainProxy(ctx context.Context, averageBlockTime time.Duration, parser ChainParser, conn grpcConnectorInterface, rpcProviderEndpoint lavasession.RPCProviderEndpoint) (ChainProxy, error) {
cp := &GrpcChainProxy{
BaseChainProxy: BaseChainProxy{averageBlockTime: averageBlockTime, ErrorHandler: &GRPCErrorHandler{}, ChainID: rpcProviderEndpoint.ChainID, HashedNodeUrl: chainproxy.HashURL(rpcProviderEndpoint.NodeUrls[0].Url)},
descriptorsCache: &grpcDescriptorCache{},
descriptorsCache: &common.SafeSyncMap[string, *desc.MethodDescriptor]{},
}
cp.conn = conn
if cp.conn == nil {
Expand Down Expand Up @@ -471,9 +451,12 @@ func (cp *GrpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{},
descriptorSource := rpcInterfaceMessages.DescriptorSourceFromServer(cl)
svc, methodName := rpcInterfaceMessages.ParseSymbol(nodeMessage.Path)

// check if we have method descriptor already cached.
methodDescriptor := cp.descriptorsCache.getDescriptor(methodName)
if methodDescriptor == nil { // method descriptor not cached yet, need to fetch it and add to cache
// Check if we have method descriptor already cached.
// The reason we do Load and then Store here, instead of LoadOrStore:
// On the worst case scenario, where 2 threads are accessing the map at the same time, the same descriptor will be stored twice.
// It is better than the alternative, which is always creating the descriptor, since the outcome is the same.
methodDescriptor, found, _ := cp.descriptorsCache.Load(methodName)
if !found { // method descriptor not cached yet, need to fetch it and add to cache
var descriptor desc.Descriptor
if descriptor, err = descriptorSource.FindSymbol(svc); err != nil {
return nil, "", nil, utils.LavaFormatError("descriptorSource.FindSymbol", err, utils.Attribute{Key: "GUID", Value: ctx})
Expand All @@ -488,7 +471,7 @@ func (cp *GrpcChainProxy) SendNodeMsg(ctx context.Context, ch chan interface{},
}

// add the descriptor to the chainProxy cache
cp.descriptorsCache.setDescriptor(methodName, methodDescriptor)
cp.descriptorsCache.Store(methodName, methodDescriptor)
omerlavanet marked this conversation as resolved.
Show resolved Hide resolved
}

msgFactory := dynamic.NewMessageFactoryWithDefaults()
Expand Down
21 changes: 16 additions & 5 deletions protocol/chaintracker/chain_tracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ type ChainTracker struct {
blockEventsGap []time.Duration
blockTimeUpdatables map[blockTimeUpdatable]struct{}
pmetrics *metrics.ProviderMetricsManager

// initial config
averageBlockTime time.Duration
serverAddress string
}

// this function returns block hashes of the blocks: [from block - to block] inclusive. an additional specific block hash can be provided. order is sorted ascending
Expand Down Expand Up @@ -570,6 +574,16 @@ func (ct *ChainTracker) serve(ctx context.Context, listenAddr string) error {
return nil
}

func (ct *ChainTracker) StartAndServe(ctx context.Context) error {
err := ct.start(ctx, ct.averageBlockTime)
if err != nil {
return err
}

err = ct.serve(ctx, ct.serverAddress)
return err
}

func NewChainTracker(ctx context.Context, chainFetcher ChainFetcher, config ChainTrackerConfig) (chainTracker *ChainTracker, err error) {
if !rand.Initialized() {
utils.LavaFormatFatal("can't start chainTracker with nil rand source", nil)
Expand Down Expand Up @@ -598,16 +612,13 @@ func NewChainTracker(ctx context.Context, chainFetcher ChainFetcher, config Chai
startupTime: time.Now(),
pmetrics: config.Pmetrics,
pollingTimeMultiplier: time.Duration(pollingTime),
averageBlockTime: config.AverageBlockTime,
serverAddress: config.ServerAddress,
}
if chainFetcher == nil {
return nil, utils.LavaFormatError("can't start chainTracker with nil chainFetcher argument", nil)
}
chainTracker.endpoint = chainFetcher.FetchEndpoint()
err = chainTracker.start(ctx, config.AverageBlockTime)
if err != nil {
return nil, err
}

err = chainTracker.serve(ctx, config.ServerAddress)
return chainTracker, err
}
7 changes: 7 additions & 0 deletions protocol/chaintracker/chain_tracker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ func TestChainTracker(t *testing.T) {

chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(tt.fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(tt.mockBlocks)}
chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
chainTracker.StartAndServe(context.Background())
require.NoError(t, err)
for _, advancement := range tt.advancements {
for i := 0; i < int(advancement); i++ {
Expand Down Expand Up @@ -218,6 +219,7 @@ func TestChainTrackerRangeOnly(t *testing.T) {

chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(tt.fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(tt.mockBlocks)}
chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
chainTracker.StartAndServe(context.Background())
require.NoError(t, err)
for _, advancement := range tt.advancements {
for i := 0; i < int(advancement); i++ {
Expand Down Expand Up @@ -302,6 +304,7 @@ func TestChainTrackerCallbacks(t *testing.T) {
chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks), ForkCallback: forkCallback, NewLatestCallback: newBlockCallback}
chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
require.NoError(t, err)
chainTracker.StartAndServe(context.Background())
totalAdvancement := 0
t.Run("one long test", func(t *testing.T) {
for _, tt := range tests {
Expand Down Expand Up @@ -368,6 +371,7 @@ func TestChainTrackerFetchSpreadAcrossPollingTime(t *testing.T) {
chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: localTimeForPollingMock, ServerBlockMemory: uint64(mockBlocks)}
tracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
require.NoError(t, err)
tracker.StartAndServe(context.Background())
// fool the tracker so it thinks blocks will come every localTimeForPollingMock (ms), and not adjust it's polling timers
for i := 0; i < 50; i++ {
tracker.AddBlockGap(localTimeForPollingMock, 1)
Expand Down Expand Up @@ -491,6 +495,7 @@ func TestChainTrackerPollingTimeUpdate(t *testing.T) {
mockChainFetcher.AdvanceBlock()
chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: play.localTimeForPollingMock, ServerBlockMemory: uint64(mockBlocks)}
tracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
tracker.StartAndServe(context.Background())
tracker.RegisterForBlockTimeUpdates(&mockTimeUpdater)
require.NoError(t, err)
// initial delay
Expand Down Expand Up @@ -555,6 +560,7 @@ func TestChainTrackerMaintainMemory(t *testing.T) {
chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks), ForkCallback: forkCallback}
chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
require.NoError(t, err)
chainTracker.StartAndServe(context.Background())
t.Run("one long test", func(t *testing.T) {
for _, tt := range tests {
utils.LavaFormatInfo(startedTestStr + tt.name)
Expand Down Expand Up @@ -607,6 +613,7 @@ func TestFindRequestedBlockHash(t *testing.T) {
chainTrackerConfig := chaintracker.ChainTrackerConfig{BlocksToSave: uint64(fetcherBlocks), AverageBlockTime: TimeForPollingMock, ServerBlockMemory: uint64(mockBlocks)}
chainTracker, err := chaintracker.NewChainTracker(context.Background(), mockChainFetcher, chainTrackerConfig)
require.NoError(t, err)
chainTracker.StartAndServe(context.Background())
latestBlock, onlyLatestBlockData, _, err := chainTracker.GetLatestBlockData(spectypes.LATEST_BLOCK, spectypes.LATEST_BLOCK, spectypes.NOT_APPLICABLE)
require.NoError(t, err)
require.Equal(t, currentLatestBlockInMock, latestBlock)
Expand Down
51 changes: 51 additions & 0 deletions protocol/common/safe_sync_map.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package common

import (
"sync"

"github.com/lavanet/lava/v3/utils"
)

type SafeSyncMap[K, V any] struct {
localMap sync.Map
}

func (ssm *SafeSyncMap[K, V]) Store(key K, toSet V) {
ssm.localMap.Store(key, toSet)
}

func (ssm *SafeSyncMap[K, V]) Load(key K) (ret V, ok bool, err error) {
value, ok := ssm.localMap.Load(key)
if !ok {
return ret, ok, nil
}
ret, ok = value.(V)
if !ok {
return ret, false, utils.LavaFormatError("invalid usage of syncmap, could not cast result into a PolicyUpdater", nil)
}
return ret, true, nil
}

// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
// The function returns the value that was loaded or stored.
func (ssm *SafeSyncMap[K, V]) LoadOrStore(key K, value V) (ret V, loaded bool, err error) {
actual, loaded := ssm.localMap.LoadOrStore(key, value)
if loaded {
// loaded from map
var ok bool
ret, ok = actual.(V)
if !ok {
return ret, false, utils.LavaFormatError("invalid usage of sync map, could not cast result into a PolicyUpdater", nil)
}
return ret, true, nil
}

// stored in map
return value, false, nil
}

func (ssm *SafeSyncMap[K, V]) Range(f func(key, value any) bool) {
ssm.localMap.Range(f)
}
1 change: 1 addition & 0 deletions protocol/integration/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ func createRpcProvider(t *testing.T, ctx context.Context, consumerAddress string
mockChainFetcher := NewMockChainFetcher(1000, int64(blocksToSaveChainTracker), nil)
chainTracker, err := chaintracker.NewChainTracker(ctx, mockChainFetcher, chainTrackerConfig)
require.NoError(t, err)
chainTracker.StartAndServe(ctx)
reliabilityManager := reliabilitymanager.NewReliabilityManager(chainTracker, &mockProviderStateTracker, account.Addr.String(), chainRouter, chainParser)
mockReliabilityManager := NewMockReliabilityManager(reliabilityManager)
rpcProviderServer.ServeRPCRequests(ctx, rpcProviderEndpoint, chainParser, rws, providerSessionManager, mockReliabilityManager, account.SK, nil, chainRouter, &mockProviderStateTracker, account.Addr, lavaChainID, rpcprovider.DEFAULT_ALLOWED_MISSING_CU, nil, nil, nil, false)
Expand Down
47 changes: 0 additions & 47 deletions protocol/rpcconsumer/policies_map.go

This file was deleted.

Loading
Loading