Skip to content

Commit

Permalink
test: Ports key assignment to the driver on the PSS feature branch (#…
Browse files Browse the repository at this point in the history
…1628)

* Port key assignment to MBT driver

* Add comment and make var names clearer
  • Loading branch information
p-offtermatt authored Feb 8, 2024
1 parent 6e07565 commit e0491ed
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 30 deletions.
15 changes: 14 additions & 1 deletion tests/mbt/driver/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
abcitypes "github.com/cometbft/cometbft/abci/types"
cmttypes "github.com/cometbft/cometbft/types"

"github.com/cometbft/cometbft/proto/tendermint/crypto"
appConsumer "github.com/cosmos/interchain-security/v4/app/consumer"
appProvider "github.com/cosmos/interchain-security/v4/app/provider"
simibc "github.com/cosmos/interchain-security/v4/testutil/simibc"
Expand Down Expand Up @@ -123,9 +124,13 @@ func (s *Driver) consumerPower(i int64, chain ChainId) (int64, error) {
return v.Power, nil
}

func (s *Driver) stakingValidator(i int64) (stakingtypes.Validator, bool) {
return s.providerStakingKeeper().GetValidator(s.ctx(PROVIDER), s.validator(i))
}

// providerPower returns the power(=number of bonded tokens) of the i-th validator on the provider.
func (s *Driver) providerPower(i int64) (int64, error) {
v, found := s.providerStakingKeeper().GetValidator(s.ctx(PROVIDER), s.validator(i))
v, found := s.stakingValidator(i)
if !found {
return 0, fmt.Errorf("validator with id %v not found on provider", i)
} else {
Expand Down Expand Up @@ -370,6 +375,14 @@ func (s *Driver) setTime(chain ChainId, newTime time.Time) {
testChain.App.BeginBlock(abcitypes.RequestBeginBlock{Header: testChain.CurrentHeader})
}

func (s *Driver) AssignKey(chain ChainId, valIndex int64, key crypto.PublicKey) error {
stakingVal, found := s.stakingValidator(valIndex)
if !found {
return fmt.Errorf("validator with id %v not found on provider", valIndex)
}
return s.providerKeeper().AssignConsumerKey(s.providerCtx(), string(chain), stakingVal, key)
}

// DeliverPacketToConsumer delivers a packet from the provider to the given consumer recipient.
// It updates the client before delivering the packet.
// Since the channel is ordered, the packet that is delivered is the first packet in the outbox.
Expand Down
3 changes: 2 additions & 1 deletion tests/mbt/driver/generate_more_traces.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ go run ./... -modelPath=../model/ccv_boundeddrift.qnt -step stepBoundedDrift -in
echo "Generating synced traces with maturations"
go run ./... -modelPath=../model/ccv_sync.qnt -init initSync -step stepSync -invariant CanReceiveMaturations -traceFolder traces/sync_mat -numTraces 20 -numSteps 300 -numSamples 20
echo "Generating long synced traces without invariants"
go run ./... -modelPath=../model/ccv_sync.qnt -init initSync -step stepSync -traceFolder traces/sync_noinv -numTraces 20 -numSteps 500 -numSamples 1
go run ./... -modelPath=../model/ccv_sync.qnt -init initSync -step stepSync -traceFolder traces/sync_noinv -numTraces 20 -numSteps 500 -numSamples 1
go run ./... -modelPath=../model/ccv_boundeddrift.qnt --step stepBoundedDriftKeyAssignment --traceFolder traces/bound_key -numTraces 20 -numSteps 100 -numSamples 20
3 changes: 2 additions & 1 deletion tests/mbt/driver/generate_traces.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ go run ./... -modelPath=../model/ccv_boundeddrift.qnt -step stepBoundedDrift -in
echo "Generating synced traces with maturations"
go run ./... -modelPath=../model/ccv_sync.qnt -init initSync -step stepSync -invariant CanReceiveMaturations -traceFolder traces/sync_mat -numTraces 1 -numSteps 300 -numSamples 20
echo "Generating long synced traces without invariants"
go run ./... -modelPath=../model/ccv_sync.qnt -init initSync -step stepSync -traceFolder traces/sync_noinv -numTraces 1 -numSteps 500 -numSamples 1
go run ./... -modelPath=../model/ccv_sync.qnt -init initSync -step stepSync -traceFolder traces/sync_noinv -numTraces 1 -numSteps 500 -numSamples 1
go run ./... -modelPath=../model/ccv_boundeddrift.qnt --step stepBoundedDriftKeyAssignment --traceFolder traces/bound_key -numTraces 1 -numSteps 100 -numSamples 20
129 changes: 102 additions & 27 deletions tests/mbt/driver/mbt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@ import (
"github.com/kylelemons/godebug/pretty"
"github.com/stretchr/testify/require"

sdktypes "github.com/cosmos/cosmos-sdk/types"

cmttypes "github.com/cometbft/cometbft/types"

tmencoding "github.com/cometbft/cometbft/crypto/encoding"
"github.com/cosmos/interchain-security/v4/testutil/integration"

sdktypes "github.com/cosmos/cosmos-sdk/types"

providertypes "github.com/cosmos/interchain-security/v4/x/ccv/provider/types"
)

Expand Down Expand Up @@ -69,6 +72,7 @@ func TestMBT(t *testing.T) {
t.Logf("Number of sent packets: %v", stats.numSentPackets)
t.Logf("Number of blocks: %v", stats.numBlocks)
t.Logf("Number of transactions: %v", stats.numTxs)
t.Logf("Number of key assignments: %v", stats.numKeyAssignments)
t.Logf("Average summed block time delta passed per trace: %v", stats.totalBlockTimePassedPerTrace/time.Duration(numTraces))
}

Expand Down Expand Up @@ -117,6 +121,25 @@ func RunItfTrace(t *testing.T, path string) {

t.Log("Chains are: ", chains)

// generate keys that can be assigned on consumers, according to the ConsumerAddresses in the trace
consumerAddressesExpr := params["ConsumerAddresses"].Value.(itf.ListExprType)

_, _, consumerPrivVals, err := integration.CreateValidators(len(consumerAddressesExpr))
require.NoError(t, err, "Error creating consumer signers")

// consumerAddrNames are the human readable names of consumer addresses in the model
// "realAddrs" are the addresses of the consumer keys on chain
// these maps relate the consumerAddrNames to the priv validators (from which one can get the real address)
// and from the real ddresses to the consumerAddrNames to allow converting between the two easily
consumerAddrNamesToPrivVals := make(map[string]cmttypes.PrivValidator, len(consumerAddressesExpr))
realAddrsToModelConsAddrs := make(map[string]string, len(consumerAddressesExpr))
i := 0
for address, privVal := range consumerPrivVals {
consumerAddrNamesToPrivVals[consumerAddressesExpr[i].Value.(string)] = privVal
realAddrsToModelConsAddrs[address] = consumerAddressesExpr[i].Value.(string)
i++
}

// create params struct
vscTimeout := time.Duration(params["VscTimeout"].Value.(int64)) * time.Second

Expand Down Expand Up @@ -145,6 +168,15 @@ func RunItfTrace(t *testing.T, path string) {
valSet, addressMap, signers, err := CreateValSet(initialValSet)
require.NoError(t, err, "Error creating validator set")

// get the set of signers for consumers: the validator signers, plus signers for the assignable addresses
consumerSigners := make(map[string]cmttypes.PrivValidator, 0)
for consAddr, consPrivVal := range consumerPrivVals {
consumerSigners[consAddr] = consPrivVal
}
for consAddr, signer := range signers {
consumerSigners[consAddr] = signer
}

// get a slice of validators in the right order
nodes := make([]*cmttypes.Validator, len(valNames))
for i, valName := range valNames {
Expand Down Expand Up @@ -211,6 +243,10 @@ func RunItfTrace(t *testing.T, path string) {
// and then increment the rest of the time
runningConsumersBefore := driver.runningConsumers()
driver.endAndBeginBlock("provider", 1*time.Nanosecond)
for _, consumer := range driver.runningConsumers() {
UpdateProviderClientOnConsumer(t, driver, consumer.ChainId)
}

driver.endAndBeginBlock("provider", time.Duration(timeAdvancement)*time.Second-1*time.Nanosecond)
runningConsumersAfter := driver.runningConsumers()

Expand Down Expand Up @@ -243,7 +279,7 @@ func RunItfTrace(t *testing.T, path string) {
consumer.Value.(string),
modelParams,
driver.providerChain().Vals,
signers,
consumerSigners,
nodes,
valNames,
driver.providerChain(),
Expand All @@ -268,11 +304,8 @@ func RunItfTrace(t *testing.T, path string) {
if len(consumersToStart) > 0 && consumer.ChainId == consumersToStart[len(consumersToStart)-1].Value.(string) {
continue
}
consumerChainId := consumer.ChainId

driver.path(ChainId(consumerChainId)).AddClientHeader(PROVIDER, driver.providerHeader())
err := driver.path(ChainId(consumerChainId)).UpdateClient(consumerChainId, false)
require.True(t, err == nil, "Error updating client from %v on provider: %v", consumerChainId, err)
UpdateProviderClientOnConsumer(t, driver, consumer.ChainId)
}

case "EndAndBeginBlockForConsumer":
Expand All @@ -286,13 +319,12 @@ func RunItfTrace(t *testing.T, path string) {
_ = headerBefore

driver.endAndBeginBlock(ChainId(consumerChain), 1*time.Nanosecond)
UpdateConsumerClientOnProvider(t, driver, consumerChain)

driver.endAndBeginBlock(ChainId(consumerChain), time.Duration(timeAdvancement)*time.Second-1*time.Nanosecond)

// update the client on the provider
consumerHeader := driver.chain(ChainId(consumerChain)).LastHeader
driver.path(ChainId(consumerChain)).AddClientHeader(consumerChain, consumerHeader)
err := driver.path(ChainId(consumerChain)).UpdateClient(PROVIDER, false)
require.True(t, err == nil, "Error updating client from %v on provider: %v", consumerChain, err)
UpdateConsumerClientOnProvider(t, driver, consumerChain)

case "DeliverVscPacket":
consumerChain := lastAction["consumerChain"].Value.(string)
Expand Down Expand Up @@ -328,8 +360,26 @@ func RunItfTrace(t *testing.T, path string) {
expectError = false
driver.DeliverPacketFromConsumer(ChainId(consumerChain), expectError)
}
default:
case "KeyAssignment":
consumerChain := lastAction["consumerChain"].Value.(string)
node := lastAction["validator"].Value.(string)
consumerAddr := lastAction["consumerAddr"].Value.(string)

t.Log("KeyAssignment", consumerChain, node, consumerAddr)
stats.numKeyAssignments++

valIndex := getIndexOfString(node, valNames)
assignedPrivVal := consumerAddrNamesToPrivVals[consumerAddr]
assignedKey, err := assignedPrivVal.GetPubKey()
require.NoError(t, err, "Error getting pubkey")

protoPubKey, err := tmencoding.PubKeyToProto(assignedKey)
require.NoError(t, err, "Error converting pubkey to proto")

error := driver.AssignKey(ChainId(consumerChain), int64(valIndex), protoPubKey)
require.NoError(t, error, "Error assigning key")

default:
log.Fatalf("Error loading trace file %s, step %v: do not know action type %s",
path, index, actionKind)
}
Expand Down Expand Up @@ -364,7 +414,7 @@ func RunItfTrace(t *testing.T, path string) {
require.Equal(t, modelRunningConsumers, actualRunningConsumers, "Running consumers do not match")

// check validator sets - provider current validator set should be the one from the staking keeper
CompareValidatorSets(t, driver, currentModelState, actualRunningConsumers)
CompareValidatorSets(t, driver, currentModelState, actualRunningConsumers, realAddrsToModelConsAddrs)

// check times - sanity check that the block times match the ones from the model
CompareTimes(driver, actualRunningConsumers, currentModelState, timeOffset)
Expand All @@ -383,7 +433,27 @@ func RunItfTrace(t *testing.T, path string) {
t.Log("🟢 Trace is ok!")
}

func CompareValidatorSets(t *testing.T, driver *Driver, currentModelState map[string]itf.Expr, consumers []string) {
func UpdateProviderClientOnConsumer(t *testing.T, driver *Driver, consumerChainId string) {
driver.path(ChainId(consumerChainId)).AddClientHeader(PROVIDER, driver.providerHeader())
err := driver.path(ChainId(consumerChainId)).UpdateClient(consumerChainId, false)
require.True(t, err == nil, "Error updating client from %v on provider: %v", consumerChainId, err)
}

func UpdateConsumerClientOnProvider(t *testing.T, driver *Driver, consumerChain string) {
consumerHeader := driver.chain(ChainId(consumerChain)).LastHeader
driver.path(ChainId(consumerChain)).AddClientHeader(consumerChain, consumerHeader)
err := driver.path(ChainId(consumerChain)).UpdateClient(PROVIDER, false)
require.True(t, err == nil, "Error updating client from %v on provider: %v", consumerChain, err)
}

func CompareValidatorSets(
t *testing.T,
driver *Driver,
currentModelState map[string]itf.Expr,
consumers []string,
// a map from real addresses to the names of those consumer addresses in the model
keyAddrsToModelConsAddrName map[string]string,
) {
t.Helper()
modelValSet := ValidatorSet(currentModelState, "provider")

Expand All @@ -407,23 +477,28 @@ func CompareValidatorSets(t *testing.T, driver *Driver, currentModelState map[st
pubkey, err := val.ConsPubKey()
require.NoError(t, err, "Error getting pubkey")

consAddr := providertypes.NewConsumerConsAddress(sdktypes.ConsAddress(pubkey.Address().Bytes()))
consAddrModelName, ok := keyAddrsToModelConsAddrName[pubkey.Address().String()]
if ok { // the node has a key assigned, use the name of the consumer address in the model
consumerCurValSet[consAddrModelName] = val.Power
} else { // the node doesn't have a key assigned yet, get the validator moniker
consAddr := providertypes.NewConsumerConsAddress(sdktypes.ConsAddress(pubkey.Address().Bytes()))

// the consumer vals right now are CrossChainValidators, for which we don't know their mnemonic
// so we need to find the mnemonic of the consumer val now to enter it by name in the map
// the consumer vals right now are CrossChainValidators, for which we don't know their mnemonic
// so we need to find the mnemonic of the consumer val now to enter it by name in the map

// get the address on the provider that corresponds to the consumer address
providerConsAddr, found := driver.providerKeeper().GetValidatorByConsumerAddr(driver.providerCtx(), consumer, consAddr)
if !found {
providerConsAddr = providertypes.NewProviderConsAddress(consAddr.Address)
}
// get the address on the provider that corresponds to the consumer address
providerConsAddr, found := driver.providerKeeper().GetValidatorByConsumerAddr(driver.providerCtx(), consumer, consAddr)
if !found {
providerConsAddr = providertypes.NewProviderConsAddress(consAddr.Address)
}

// get the validator for that address on the provider
providerVal, found := driver.providerStakingKeeper().GetValidatorByConsAddr(driver.providerCtx(), providerConsAddr.Address)
require.True(t, found, "Error getting provider validator")
// get the validator for that address on the provider
providerVal, found := driver.providerStakingKeeper().GetValidatorByConsAddr(driver.providerCtx(), providerConsAddr.Address)
require.True(t, found, "Error getting provider validator")

// use the moniker of that validator
consumerCurValSet[providerVal.GetMoniker()] = val.Power
// use the moniker of that validator
consumerCurValSet[providerVal.GetMoniker()] = val.Power
}
}
require.NoError(t, CompareValSet(modelValSet, consumerCurValSet), "Validator sets do not match for consumer %v", consumer)
}
Expand Down
2 changes: 2 additions & 0 deletions tests/mbt/driver/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,6 @@ type Stats struct {
numTxs int

totalBlockTimePassedPerTrace time.Duration

numKeyAssignments int
}

0 comments on commit e0491ed

Please sign in to comment.