Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix!: Replace GetAllConsumerChains with lightweight version #1946

Merged
merged 5 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- Replace `GetAllConsumerChains` with lightweight version
(`GetAllRegisteredConsumerChainIDs`) that doesn't call into the staking module
([\#1946](https://github.com/cosmos/interchain-security/pull/1946))
31 changes: 13 additions & 18 deletions tests/mbt/driver/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
consumerkeeper "github.com/cosmos/interchain-security/v4/x/ccv/consumer/keeper"
consumertypes "github.com/cosmos/interchain-security/v4/x/ccv/consumer/types"
providerkeeper "github.com/cosmos/interchain-security/v4/x/ccv/provider/keeper"
providertypes "github.com/cosmos/interchain-security/v4/x/ccv/provider/types"
"github.com/cosmos/interchain-security/v4/x/ccv/types"
)

Expand Down Expand Up @@ -219,11 +218,7 @@ func (s *Driver) getStateString() string {
state.WriteString("\n")

state.WriteString("Consumers Chains:\n")
consumerChains := s.providerKeeper().GetAllConsumerChains(s.providerCtx())
chainIds := make([]string, len(consumerChains))
for i, consumerChain := range consumerChains {
chainIds[i] = consumerChain.ChainId
}
chainIds := s.providerKeeper().GetAllRegisteredConsumerChainIDs(s.providerCtx())
state.WriteString(strings.Join(chainIds, ", "))
state.WriteString("\n\n")

Expand Down Expand Up @@ -261,11 +256,11 @@ func (s *Driver) getChainStateString(chain ChainId) string {
if !s.isProviderChain(chain) {
// Check whether the chain is in the consumer chains on the provider

consumerChains := s.providerKeeper().GetAllConsumerChains(s.providerCtx())
consumerChainIDs := s.providerKeeper().GetAllRegisteredConsumerChainIDs(s.providerCtx())

found := false
for _, consumerChain := range consumerChains {
if consumerChain.ChainId == string(chain) {
for _, consumerChainID := range consumerChainIDs {
if consumerChainID == string(chain) {
found = true
}
}
Expand Down Expand Up @@ -369,16 +364,16 @@ func (s *Driver) endAndBeginBlock(chain ChainId, timeAdvancement time.Duration)
return header
}

func (s *Driver) runningConsumers() []providertypes.Chain {
consumersOnProvider := s.providerKeeper().GetAllConsumerChains(s.providerCtx())
func (s *Driver) runningConsumerChainIDs() []ChainId {
consumerIDsOnProvider := s.providerKeeper().GetAllRegisteredConsumerChainIDs(s.providerCtx())

consumersWithIntactChannel := make([]providertypes.Chain, 0)
for _, consumer := range consumersOnProvider {
if s.path(ChainId(consumer.ChainId)).Path.EndpointA.GetChannel().State == channeltypes.CLOSED ||
s.path(ChainId(consumer.ChainId)).Path.EndpointB.GetChannel().State == channeltypes.CLOSED {
consumersWithIntactChannel := make([]ChainId, 0)
for _, consumerChainID := range consumerIDsOnProvider {
if s.path(ChainId(consumerChainID)).Path.EndpointA.GetChannel().State == channeltypes.CLOSED ||
s.path(ChainId(consumerChainID)).Path.EndpointB.GetChannel().State == channeltypes.CLOSED {
continue
}
consumersWithIntactChannel = append(consumersWithIntactChannel, consumer)
consumersWithIntactChannel = append(consumersWithIntactChannel, ChainId(consumerChainID))
}
return consumersWithIntactChannel
}
Expand Down Expand Up @@ -447,8 +442,8 @@ func (s *Driver) RequestSlash(
// DeliverAcks delivers, for each path,
// all possible acks (up to math.MaxInt many per path).
func (s *Driver) DeliverAcks() {
for _, chain := range s.runningConsumers() {
path := s.path(ChainId(chain.ChainId))
for _, chainID := range s.runningConsumerChainIDs() {
path := s.path(chainID)
path.DeliverAcks(path.Path.EndpointA.Chain.ChainID, math.MaxInt)
path.DeliverAcks(path.Path.EndpointB.Chain.ChainID, math.MaxInt)
}
Expand Down
72 changes: 36 additions & 36 deletions tests/mbt/driver/mbt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,21 +304,21 @@ func RunItfTrace(t *testing.T, path string) {
// needs a header of height H+1 to accept the packet
// so, we do two blocks, one with a very small increment,
// and then another to increment the rest of the time
runningConsumersBefore := driver.runningConsumers()
runningConsumerChainIDsBefore := driver.runningConsumerChainIDs()

driver.endAndBeginBlock("provider", 1*time.Nanosecond)
for _, consumer := range driver.runningConsumers() {
UpdateProviderClientOnConsumer(t, driver, consumer.ChainId)
for _, consumerChainID := range driver.runningConsumerChainIDs() {
UpdateProviderClientOnConsumer(t, driver, string(consumerChainID))
}
driver.endAndBeginBlock("provider", time.Duration(timeAdvancement)*time.Second-1*time.Nanosecond)

runningConsumersAfter := driver.runningConsumers()
runningConsumerChainIDsAfter := driver.runningConsumerChainIDs()

// the consumers that were running before but not after must have timed out
for _, consumer := range runningConsumersBefore {
for _, consumerChainID := range runningConsumerChainIDsBefore {
found := false
for _, consumerAfter := range runningConsumersAfter {
if consumerAfter.ChainId == consumer.ChainId {
for _, consumerChainIDAfter := range runningConsumerChainIDsAfter {
if consumerChainIDAfter == consumerChainID {
found = true
break
}
Expand All @@ -332,8 +332,8 @@ func RunItfTrace(t *testing.T, path string) {
// because setting up chains will modify timestamps
// when the coordinator is starting chains
lastTimestamps := make(map[ChainId]time.Time, len(consumers))
for _, consumer := range driver.runningConsumers() {
lastTimestamps[ChainId(consumer.ChainId)] = driver.runningTime(ChainId(consumer.ChainId))
for _, consumerChainID := range driver.runningConsumerChainIDs() {
lastTimestamps[consumerChainID] = driver.runningTime(consumerChainID)
}

driver.coordinator.CurrentTime = driver.runningTime("provider")
Expand Down Expand Up @@ -364,12 +364,12 @@ func RunItfTrace(t *testing.T, path string) {
// for all connected consumers, update the clients...
// unless it was the last consumer to be started, in which case it already has the header
// as we called driver.setupConsumer
for _, consumer := range driver.runningConsumers() {
if len(consumersToStart) > 0 && consumer.ChainId == consumersToStart[len(consumersToStart)-1].Value.(string) {
for _, consumerChainID := range driver.runningConsumerChainIDs() {
if len(consumersToStart) > 0 && string(consumerChainID) == consumersToStart[len(consumersToStart)-1].Value.(string) {
continue
}

UpdateProviderClientOnConsumer(t, driver, consumer.ChainId)
UpdateProviderClientOnConsumer(t, driver, string(consumerChainID))
}

case "EndAndBeginBlockForConsumer":
Expand Down Expand Up @@ -490,33 +490,33 @@ func RunItfTrace(t *testing.T, path string) {
t.Logf("Comparing model state to actual state...")

// compare the running consumers
modelRunningConsumers := RunningConsumers(currentModelState)
modelRunningConsumerChainIDs := RunningConsumers(currentModelState)

systemRunningConsumers := driver.runningConsumers()
actualRunningConsumers := make([]string, len(systemRunningConsumers))
for i, chain := range systemRunningConsumers {
actualRunningConsumers[i] = chain.ChainId
systemRunningConsumerChainIDs := driver.runningConsumerChainIDs()
actualRunningConsumerChainIDs := make([]string, len(systemRunningConsumerChainIDs))
for i, chainID := range systemRunningConsumerChainIDs {
actualRunningConsumerChainIDs[i] = string(chainID)
}

// sort the slices so that we can compare them
sort.Slice(modelRunningConsumers, func(i, j int) bool {
return modelRunningConsumers[i] < modelRunningConsumers[j]
sort.Slice(modelRunningConsumerChainIDs, func(i, j int) bool {
return modelRunningConsumerChainIDs[i] < modelRunningConsumerChainIDs[j]
})
sort.Slice(actualRunningConsumers, func(i, j int) bool {
return actualRunningConsumers[i] < actualRunningConsumers[j]
sort.Slice(actualRunningConsumerChainIDs, func(i, j int) bool {
return actualRunningConsumerChainIDs[i] < actualRunningConsumerChainIDs[j]
})

require.Equal(t, modelRunningConsumers, actualRunningConsumers, "Running consumers do not match")
require.Equal(t, modelRunningConsumerChainIDs, actualRunningConsumerChainIDs, "Running consumers do not match")

// check validator sets - provider current validator set should be the one from the staking keeper
CompareValidatorSets(t, driver, currentModelState, actualRunningConsumers, realAddrsToModelConsAddrs)
CompareValidatorSets(t, driver, currentModelState, actualRunningConsumerChainIDs, realAddrsToModelConsAddrs)

// check times - sanity check that the block times match the ones from the model
CompareTimes(driver, actualRunningConsumers, currentModelState, timeOffset)
CompareTimes(driver, actualRunningConsumerChainIDs, currentModelState, timeOffset)

// check sent packets: we check that the package queues in the model and the system have the same length.
for _, consumer := range actualRunningConsumers {
ComparePacketQueues(t, driver, currentModelState, consumer, timeOffset)
for _, consumerChainID := range actualRunningConsumerChainIDs {
ComparePacketQueues(t, driver, currentModelState, consumerChainID, timeOffset)
}
// compare that the sent packets on the proider match the model
CompareSentPacketsOnProvider(driver, currentModelState, timeOffset)
Expand All @@ -526,8 +526,8 @@ func RunItfTrace(t *testing.T, path string) {
CompareJailedValidators(driver, currentModelState, timeOffset, addressMap)

// for all newly sent vsc packets, figure out which vsc id in the model they correspond to
for _, consumer := range actualRunningConsumers {
actualPackets := driver.packetQueue(PROVIDER, ChainId(consumer))
for _, consumerChainID := range actualRunningConsumerChainIDs {
actualPackets := driver.packetQueue(PROVIDER, ChainId(consumerChainID))
actualNewPackets := make([]types.ValidatorSetChangePacketData, 0)
for _, packet := range actualPackets {

Expand All @@ -543,7 +543,7 @@ func RunItfTrace(t *testing.T, path string) {
actualNewPackets = append(actualNewPackets, packetData)
}

modelPackets := PacketQueue(currentModelState, PROVIDER, consumer)
modelPackets := PacketQueue(currentModelState, PROVIDER, consumerChainID)
newModelVscIds := make([]uint64, 0)
for _, packet := range modelPackets {
modelVscId := uint64(packet.Value.(itf.MapExprType)["value"].Value.(itf.MapExprType)["id"].Value.(int64))
Expand Down Expand Up @@ -781,15 +781,15 @@ func CompareValSet(modelValSet map[string]itf.Expr, systemValSet map[string]int6
}

func CompareSentPacketsOnProvider(driver *Driver, currentModelState map[string]itf.Expr, timeOffset time.Time) {
for _, consumer := range driver.runningConsumers() {
vscSendTimestamps := driver.providerKeeper().GetAllVscSendTimestamps(driver.providerCtx(), consumer.ChainId)
for _, consumerChainID := range driver.runningConsumerChainIDs() {
vscSendTimestamps := driver.providerKeeper().GetAllVscSendTimestamps(driver.providerCtx(), string(consumerChainID))

actualVscSendTimestamps := make([]time.Time, 0)
for _, vscSendTimestamp := range vscSendTimestamps {
actualVscSendTimestamps = append(actualVscSendTimestamps, vscSendTimestamp.Timestamp)
}

modelVscSendTimestamps := VscSendTimestamps(currentModelState, consumer.ChainId)
modelVscSendTimestamps := VscSendTimestamps(currentModelState, string(consumerChainID))

for i, modelVscSendTimestamp := range modelVscSendTimestamps {
actualTimeWithOffset := actualVscSendTimestamps[i].Unix() - timeOffset.Unix()
Expand All @@ -798,7 +798,7 @@ func CompareSentPacketsOnProvider(driver *Driver, currentModelState map[string]i
modelVscSendTimestamp,
actualTimeWithOffset,
"Vsc send timestamps do not match for consumer %v",
consumer.ChainId,
consumerChainID,
)
}
}
Expand Down Expand Up @@ -852,9 +852,9 @@ func (s *Stats) EnterStats(driver *Driver) {

// max number of in-flight packets
inFlightPackets := 0
for _, consumer := range driver.runningConsumers() {
inFlightPackets += len(driver.packetQueue(PROVIDER, ChainId(consumer.ChainId)))
inFlightPackets += len(driver.packetQueue(ChainId(consumer.ChainId), PROVIDER))
for _, consumerChainID := range driver.runningConsumerChainIDs() {
inFlightPackets += len(driver.packetQueue(PROVIDER, consumerChainID))
inFlightPackets += len(driver.packetQueue(consumerChainID, PROVIDER))
}
if inFlightPackets > s.maxNumInFlightPackets {
s.maxNumInFlightPackets = inFlightPackets
Expand Down
10 changes: 5 additions & 5 deletions x/ccv/provider/keeper/distribution.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ func (k Keeper) AllocateTokens(ctx sdk.Context) {
}

// Iterate over all registered consumer chains
for _, consumer := range k.GetAllConsumerChains(ctx) {
for _, consumerChainID := range k.GetAllRegisteredConsumerChainIDs(ctx) {
// transfer the consumer rewards to the distribution module account
// note that the rewards transferred are only consumer whitelisted denoms
rewardsCollected, err := k.TransferConsumerRewardsToDistributionModule(ctx, consumer.ChainId)
rewardsCollected, err := k.TransferConsumerRewardsToDistributionModule(ctx, consumerChainID)
if err != nil {
k.Logger(ctx).Error(
"fail to transfer rewards to distribution module for chain %s: %s",
consumer.ChainId,
consumerChainID,
err,
)
continue
Expand All @@ -101,7 +101,7 @@ func (k Keeper) AllocateTokens(ctx sdk.Context) {
// temporary workaround to keep CanWithdrawInvariant happy
// general discussions here: https://github.com/cosmos/cosmos-sdk/issues/2906#issuecomment-441867634
feePool := k.distributionKeeper.GetFeePool(ctx)
if k.ComputeConsumerTotalVotingPower(ctx, consumer.ChainId) == 0 {
if k.ComputeConsumerTotalVotingPower(ctx, consumerChainID) == 0 {
feePool.CommunityPool = feePool.CommunityPool.Add(rewardsCollectedDec...)
k.distributionKeeper.SetFeePool(ctx, feePool)
return
Expand All @@ -116,7 +116,7 @@ func (k Keeper) AllocateTokens(ctx sdk.Context) {
// allocate tokens to consumer validators
feeAllocated := k.AllocateTokensToConsumerValidators(
ctx,
consumer.ChainId,
consumerChainID,
feeMultiplier,
)
remaining = remaining.Sub(feeAllocated)
Expand Down
36 changes: 20 additions & 16 deletions x/ccv/provider/keeper/genesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,47 +108,51 @@ func (k Keeper) InitGenesis(ctx sdk.Context, genState *types.GenesisState) {
// ExportGenesis returns the CCV provider module's exported genesis
func (k Keeper) ExportGenesis(ctx sdk.Context) *types.GenesisState {
// get a list of all registered consumer chains
registeredChains := k.GetAllConsumerChains(ctx)
registeredChainIDs := k.GetAllRegisteredConsumerChainIDs(ctx)

var exportedVscSendTimestamps []types.ExportedVscSendTimestamp
// export states for each consumer chains
var consumerStates []types.ConsumerState
for _, chain := range registeredChains {
gen, found := k.GetConsumerGenesis(ctx, chain.ChainId)
for _, chainID := range registeredChainIDs {
// no need for the second return value of GetConsumerClientId
// as GetAllRegisteredConsumerChainIDs already iterated through
// the entire prefix range
clientID, _ := k.GetConsumerClientId(ctx, chainID)
gen, found := k.GetConsumerGenesis(ctx, chainID)
if !found {
panic(fmt.Errorf("cannot find genesis for consumer chain %s with client %s", chain.ChainId, chain.ClientId))
panic(fmt.Errorf("cannot find genesis for consumer chain %s with client %s", chainID, clientID))
}

// initial consumer chain states
cs := types.ConsumerState{
ChainId: chain.ChainId,
ClientId: chain.ClientId,
ChainId: chainID,
ClientId: clientID,
ConsumerGenesis: gen,
UnbondingOpsIndex: k.GetAllUnbondingOpIndexes(ctx, chain.ChainId),
UnbondingOpsIndex: k.GetAllUnbondingOpIndexes(ctx, chainID),
}

// try to find channel id for the current consumer chain
channelId, found := k.GetChainToChannel(ctx, chain.ChainId)
channelId, found := k.GetChainToChannel(ctx, chainID)
if found {
cs.ChannelId = channelId
cs.InitialHeight, found = k.GetInitChainHeight(ctx, chain.ChainId)
cs.InitialHeight, found = k.GetInitChainHeight(ctx, chainID)
if !found {
panic(fmt.Errorf("cannot find init height for consumer chain %s", chain.ChainId))
panic(fmt.Errorf("cannot find init height for consumer chain %s", chainID))
}
cs.SlashDowntimeAck = k.GetSlashAcks(ctx, chain.ChainId)
cs.SlashDowntimeAck = k.GetSlashAcks(ctx, chainID)
}

cs.PendingValsetChanges = k.GetPendingVSCPackets(ctx, chain.ChainId)
cs.PendingValsetChanges = k.GetPendingVSCPackets(ctx, chainID)
consumerStates = append(consumerStates, cs)

vscSendTimestamps := k.GetAllVscSendTimestamps(ctx, chain.ChainId)
exportedVscSendTimestamps = append(exportedVscSendTimestamps, types.ExportedVscSendTimestamp{ChainId: chain.ChainId, VscSendTimestamps: vscSendTimestamps})
vscSendTimestamps := k.GetAllVscSendTimestamps(ctx, chainID)
exportedVscSendTimestamps = append(exportedVscSendTimestamps, types.ExportedVscSendTimestamp{ChainId: chainID, VscSendTimestamps: vscSendTimestamps})
}

// ConsumerAddrsToPrune are added only for registered consumer chains
consumerAddrsToPrune := []types.ConsumerAddrsToPrune{}
for _, chain := range registeredChains {
consumerAddrsToPrune = append(consumerAddrsToPrune, k.GetAllConsumerAddrsToPrune(ctx, chain.ChainId)...)
for _, chainID := range registeredChainIDs {
consumerAddrsToPrune = append(consumerAddrsToPrune, k.GetAllConsumerAddrsToPrune(ctx, chainID)...)
}

params := k.GetParams(ctx)
Expand Down
Loading
Loading