Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor/test: 1093 continued #1104

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 23 additions & 66 deletions tests/integration/throttle.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types"

icstestingutils "github.com/cosmos/interchain-security/v3/testutil/ibc_testing"
"github.com/cosmos/interchain-security/v3/x/ccv/provider"
providertypes "github.com/cosmos/interchain-security/v3/x/ccv/provider/types"
ccvtypes "github.com/cosmos/interchain-security/v3/x/ccv/types"
)
Expand Down Expand Up @@ -314,15 +315,9 @@ func (s *CCVTestSuite) TestPacketSpam() {

// Recv 500 packets from consumer to provider in same block
for _, packet := range packets {
consumerPacketData := ccvtypes.ConsumerPacketData{}
consumerPacketDataV1 := ccvtypes.ConsumerPacketDataV1{}
err := ccvtypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &consumerPacketData)
if err == nil {
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
} else {
ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &consumerPacketDataV1)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketDataV1.GetSlashPacketData().FromV1())
}
consumerPacketData, err := provider.UnmarshalConsumerPacket(packet) // Same func used by provider's OnRecvPacket
s.Require().NoError(err)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
}

// Execute block to handle packets in endblock
Expand Down Expand Up @@ -374,15 +369,9 @@ func (s *CCVTestSuite) TestDoubleSignDoesNotAffectThrottling() {

// Recv 500 packets from consumer to provider in same block
for _, packet := range packets {
consumerPacketData := ccvtypes.ConsumerPacketData{}
consumerPacketDataV1 := ccvtypes.ConsumerPacketDataV1{}
err := ccvtypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &consumerPacketData)
if err == nil {
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
} else {
ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &consumerPacketDataV1)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketDataV1.GetSlashPacketData().FromV1())
}
consumerPacketData, err := provider.UnmarshalConsumerPacket(packet) // Same func used by provider's OnRecvPacket
s.Require().NoError(err)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
}

// Execute block to handle packets in endblock
Expand Down Expand Up @@ -476,18 +465,10 @@ func (s *CCVTestSuite) TestQueueOrdering() {

// Recv 500 packets from consumer to provider in same block
for i, packet := range packets {
consumerPacketData := ccvtypes.ConsumerPacketData{}
consumerPacketDataV1 := ccvtypes.ConsumerPacketDataV1{}
err := ccvtypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &consumerPacketData)
if err != nil {
ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &consumerPacketDataV1)
consumerPacketData = ccvtypes.ConsumerPacketData{
Type: consumerPacketDataV1.Type,
Data: &ccvtypes.ConsumerPacketData_SlashPacketData{
SlashPacketData: consumerPacketDataV1.GetSlashPacketData().FromV1(),
},
}
}

consumerPacketData, err := provider.UnmarshalConsumerPacket(packet) // Same func used by provider's OnRecvPacket
s.Require().NoError(err)

// Type depends on index packets were appended from above
if (i+5)%10 == 0 {
vscMaturedPacketData := consumerPacketData.GetVscMaturedPacketData()
Expand Down Expand Up @@ -700,15 +681,9 @@ func (s *CCVTestSuite) TestSlashSameValidator() {

// Recv and queue all slash packets.
for _, packet := range packets {
consumerPacketData := ccvtypes.ConsumerPacketData{}
consumerPacketDataV1 := ccvtypes.ConsumerPacketDataV1{}
err := ccvtypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &consumerPacketData)
if err == nil {
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
} else {
ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &consumerPacketDataV1)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketDataV1.GetSlashPacketData().FromV1())
}
consumerPacketData, err := provider.UnmarshalConsumerPacket(packet) // Same func used by provider's OnRecvPacket
s.Require().NoError(err)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
}

// We should have 6 pending slash packet entries queued.
Expand Down Expand Up @@ -767,15 +742,9 @@ func (s CCVTestSuite) TestSlashAllValidators() { //nolint:govet // this is a tes

// Recv and queue all slash packets.
for _, packet := range packets {
consumerPacketData := ccvtypes.ConsumerPacketData{}
consumerPacketDataV1 := ccvtypes.ConsumerPacketDataV1{}
err := ccvtypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &consumerPacketData)
if err == nil {
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
} else {
ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &consumerPacketDataV1)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketDataV1.GetSlashPacketData().FromV1())
}
consumerPacketData, err := provider.UnmarshalConsumerPacket(packet) // Same func used by provider's OnRecvPacket
s.Require().NoError(err)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
}

// We should have 24 pending slash packet entries queued.
Expand Down Expand Up @@ -820,15 +789,9 @@ func (s *CCVTestSuite) TestLeadingVSCMaturedAreDequeued() {
ibcSeqNum := uint64(i)
packet := s.constructSlashPacketFromConsumer(*bundle,
*s.providerChain.Vals.Validators[0], stakingtypes.Infraction_INFRACTION_DOWNTIME, ibcSeqNum)
packetData := ccvtypes.ConsumerPacketData{}
packetDataV1 := ccvtypes.ConsumerPacketDataV1{}
err := ccvtypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &packetData)
if err == nil {
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *packetData.GetSlashPacketData())
} else {
ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &packetDataV1)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *packetDataV1.GetSlashPacketData().FromV1())
}
consumerPacketData, err := provider.UnmarshalConsumerPacket(packet) // Same func used by provider's OnRecvPacket
s.Require().NoError(err)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
}
}

Expand Down Expand Up @@ -916,15 +879,9 @@ func (s *CCVTestSuite) TestVscMaturedHandledPerBlockLimit() {
ibcSeqNum := uint64(i)
packet := s.constructSlashPacketFromConsumer(*bundle,
*s.providerChain.Vals.Validators[0], stakingtypes.Infraction_INFRACTION_DOWNTIME, ibcSeqNum)
packetData := ccvtypes.ConsumerPacketData{}
packetDataV1 := ccvtypes.ConsumerPacketDataV1{}
err := ccvtypes.ModuleCdc.UnmarshalJSON(packet.GetData(), &packetData)
if err == nil {
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *packetData.GetSlashPacketData())
} else {
ccvtypes.ModuleCdc.MustUnmarshalJSON(packet.GetData(), &packetDataV1)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *packetDataV1.GetSlashPacketData().FromV1())
}
consumerPacketData, err := provider.UnmarshalConsumerPacket(packet) // Same func used by provider's OnRecvPacket
s.Require().NoError(err)
providerKeeper.OnRecvSlashPacket(s.providerCtx(), packet, *consumerPacketData.GetSlashPacketData())
}
}

Expand Down
60 changes: 32 additions & 28 deletions x/ccv/provider/ibc_module.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,39 +174,16 @@ func (am AppModule) OnRecvPacket(
packet channeltypes.Packet,
_ sdk.AccAddress,
) ibcexported.Acknowledgement {
var (
ack ibcexported.Acknowledgement
consumerPacket ccv.ConsumerPacketData
)

// unmarshall consumer packet
if err := ccv.ModuleCdc.UnmarshalJSON(packet.GetData(), &consumerPacket); err != nil {
// retry for v1 packet type
var v1Packet ccv.ConsumerPacketDataV1
errV1 := ccv.ModuleCdc.UnmarshalJSON(packet.GetData(), &v1Packet)
if errV1 != nil {
errAck := ccv.NewErrorAcknowledgementWithLog(ctx, fmt.Errorf("cannot unmarshal CCV packet data"))
ack = &errAck
return ack
}

if v1Packet.Type == ccv.VscMaturedPacket {
errAck := ccv.NewErrorAcknowledgementWithLog(ctx, fmt.Errorf("unexpected VSCMaturedPacket packet type"))
ack = &errAck
return ack
}

consumerPacket = ccv.ConsumerPacketData{
Type: v1Packet.Type,
Data: &ccv.ConsumerPacketData_SlashPacketData{
SlashPacketData: v1Packet.GetSlashPacketData().FromV1(),
},
}
consumerPacket, err := UnmarshalConsumerPacket(packet)
if err != nil {
errAck := ccv.NewErrorAcknowledgementWithLog(ctx, err)
return &errAck
}

// TODO: call ValidateBasic method on consumer packet data
// See: https://github.com/cosmos/interchain-security/issues/634

var ack ibcexported.Acknowledgement
switch consumerPacket.Type {
case ccv.VscMaturedPacket:
// handle VSCMaturedPacket
Expand All @@ -230,6 +207,33 @@ func (am AppModule) OnRecvPacket(
return ack
}

func UnmarshalConsumerPacket(packet channeltypes.Packet) (consumerPacket ccv.ConsumerPacketData, err error) {
// First try unmarshaling into ccv.ConsumerPacketData type
if err := ccv.ModuleCdc.UnmarshalJSON(packet.GetData(), &consumerPacket); err != nil {
// If failed, packet should be a v1 slash packet, retry for ConsumerPacketDataV1 packet type
var v1Packet ccv.ConsumerPacketDataV1
errV1 := ccv.ModuleCdc.UnmarshalJSON(packet.GetData(), &v1Packet)
if errV1 != nil {
// If neither worked, return error
return ccv.ConsumerPacketData{}, errV1
}

// VSC matured packets should not be unmarshaled as v1 packets
if v1Packet.Type == ccv.VscMaturedPacket {
return ccv.ConsumerPacketData{}, fmt.Errorf("VSC matured packets should be correctly unmarshaled")
}

// Convert from v1 packet type
consumerPacket = ccv.ConsumerPacketData{
Type: v1Packet.Type,
Data: &ccv.ConsumerPacketData_SlashPacketData{
SlashPacketData: v1Packet.GetSlashPacketData().FromV1(),
},
}
}
return consumerPacket, nil
}

// OnAcknowledgementPacket implements the IBCModule interface
func (am AppModule) OnAcknowledgementPacket(
ctx sdk.Context,
Expand Down
60 changes: 60 additions & 0 deletions x/ccv/provider/ibc_module_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
authtypes "github.com/cosmos/cosmos-sdk/x/auth/types"
capabilitytypes "github.com/cosmos/cosmos-sdk/x/capability/types"

"github.com/cosmos/ibc-go/v7/modules/core/02-client/types"
conntypes "github.com/cosmos/ibc-go/v7/modules/core/03-connection/types"
channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types"
host "github.com/cosmos/ibc-go/v7/modules/core/24-host"
Expand Down Expand Up @@ -338,3 +339,62 @@ func TestOnChanOpenConfirm(t *testing.T) {
ctrl.Finish()
}
}

func TestUnmarshalConsumerPacket(t *testing.T) {
testCases := []struct {
name string
packet channeltypes.Packet
expectedPacketData ccv.ConsumerPacketData
}{
{
name: "vsc matured",
packet: channeltypes.NewPacket(
ccv.ConsumerPacketData{
Type: ccv.VscMaturedPacket,
Data: &ccv.ConsumerPacketData_VscMaturedPacketData{
VscMaturedPacketData: &ccv.VSCMaturedPacketData{
ValsetUpdateId: 420,
},
},
}.GetBytes(),
342, "sourcePort", "sourceChannel", "destinationPort", "destinationChannel", types.Height{}, 0,
),
expectedPacketData: ccv.ConsumerPacketData{
Type: ccv.VscMaturedPacket,
Data: &ccv.ConsumerPacketData_VscMaturedPacketData{
VscMaturedPacketData: &ccv.VSCMaturedPacketData{
ValsetUpdateId: 420,
},
},
},
},
{
name: "slash packet",
packet: channeltypes.NewPacket(
ccv.ConsumerPacketData{
Type: ccv.SlashPacket,
Data: &ccv.ConsumerPacketData_SlashPacketData{
SlashPacketData: &ccv.SlashPacketData{
ValsetUpdateId: 789,
},
},
}.GetBytes(), // Note packet data is converted to v1 bytes here
342, "sourcePort", "sourceChannel", "destinationPort", "destinationChannel", types.Height{}, 0,
),
expectedPacketData: ccv.ConsumerPacketData{
Type: ccv.SlashPacket,
Data: &ccv.ConsumerPacketData_SlashPacketData{
SlashPacketData: &ccv.SlashPacketData{
ValsetUpdateId: 789,
},
},
},
},
}

for _, tc := range testCases {
actualConsumerPacketData, err := provider.UnmarshalConsumerPacket(tc.packet)
require.NoError(t, err)
require.Equal(t, tc.expectedPacketData, actualConsumerPacketData)
}
}