Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ics29:fix: store source address for query later on WriteAck #912

Merged
merged 8 commits into from
Feb 23, 2022
11 changes: 6 additions & 5 deletions modules/apps/29-fee/ibc_module.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,14 +196,15 @@ func (im IBCModule) OnRecvPacket(

ack := im.app.OnRecvPacket(ctx, packet, relayer)

forwardRelayer, found := im.keeper.GetCounterpartyAddress(ctx, relayer.String(), packet.DestinationChannel)

// incase of async aknowledgement (ack == nil) store the ForwardRelayer address for use later
if ack == nil && found {
im.keeper.SetForwardRelayerAddress(ctx, channeltypes.NewPacketId(packet.GetSourceChannel(), packet.GetSourcePort(), packet.GetSequence()), forwardRelayer)
// incase of async aknowledgement (ack == nil) store the relayer address for use later during async WriteAcknowledgement
if ack == nil {
im.keeper.SetRelayerAddressForAsyncAck(ctx, channeltypes.NewPacketId(packet.GetSourceChannel(), packet.GetSourcePort(), packet.GetSequence()), relayer.String())
return nil
}

// if forwardRelayer is not found we refund recv_fee
forwardRelayer, _ := im.keeper.GetCounterpartyAddress(ctx, relayer.String(), packet.GetSourceChannel())

return types.NewIncentivizedAcknowledgement(forwardRelayer, ack.Acknowledgement(), ack.Success())
}

Expand Down
2 changes: 1 addition & 1 deletion modules/apps/29-fee/keeper/genesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func (k Keeper) InitGenesis(ctx sdk.Context, state types.GenesisState) {
}

for _, forwardAddr := range state.ForwardRelayers {
k.SetForwardRelayerAddress(ctx, forwardAddr.PacketId, forwardAddr.Address)
k.SetRelayerAddressForAsyncAck(ctx, forwardAddr.PacketId, forwardAddr.Address)
}

for _, enabledChan := range state.FeeEnabledChannels {
Expand Down
2 changes: 1 addition & 1 deletion modules/apps/29-fee/keeper/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ func (suite *KeeperTestSuite) TestExportGenesis() {
suite.chainA.GetSimApp().IBCFeeKeeper.SetCounterpartyAddress(suite.chainA.GetContext(), sender, counterparty, ibctesting.FirstChannelID)

// set forward relayer address
suite.chainA.GetSimApp().IBCFeeKeeper.SetForwardRelayerAddress(suite.chainA.GetContext(), packetID, sender)
suite.chainA.GetSimApp().IBCFeeKeeper.SetRelayerAddressForAsyncAck(suite.chainA.GetContext(), packetID, sender)

// export genesis
genesisState := suite.chainA.GetSimApp().IBCFeeKeeper.ExportGenesis(suite.chainA.GetContext())
Expand Down
8 changes: 4 additions & 4 deletions modules/apps/29-fee/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,14 @@ func (k Keeper) GetAllRelayerAddresses(ctx sdk.Context) []types.RegisteredRelaye
return registeredAddrArr
}

// SetForwardRelayerAddress sets the forward relayer address during OnRecvPacket in case of async acknowledgement
func (k Keeper) SetForwardRelayerAddress(ctx sdk.Context, packetId channeltypes.PacketId, address string) {
// SetRelayerAddressForAsyncAck sets the forward relayer address during OnRecvPacket in case of async acknowledgement
func (k Keeper) SetRelayerAddressForAsyncAck(ctx sdk.Context, packetId channeltypes.PacketId, address string) {
store := ctx.KVStore(k.storeKey)
store.Set(types.KeyForwardRelayerAddress(packetId), []byte(address))
}

// GetForwardRelayerAddress gets forward relayer address for a particular packet
func (k Keeper) GetForwardRelayerAddress(ctx sdk.Context, packetId channeltypes.PacketId) (string, bool) {
// GetRelayerAddressForAsyncAck gets forward relayer address for a particular packet
func (k Keeper) GetRelayerAddressForAsyncAck(ctx sdk.Context, packetId channeltypes.PacketId) (string, bool) {
store := ctx.KVStore(k.storeKey)
key := types.KeyForwardRelayerAddress(packetId)
if !store.Has(key) {
Expand Down
15 changes: 12 additions & 3 deletions modules/apps/29-fee/keeper/relay.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package keeper

import (
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
capabilitytypes "github.com/cosmos/cosmos-sdk/x/capability/types"

"github.com/cosmos/ibc-go/v3/modules/apps/29-fee/types"
Expand All @@ -24,11 +25,19 @@ func (k Keeper) WriteAcknowledgement(ctx sdk.Context, chanCap *capabilitytypes.C

// retrieve the forward relayer that was stored in `onRecvPacket`
packetId := channeltypes.NewPacketId(packet.GetSourceChannel(), packet.GetSourcePort(), packet.GetSequence())
relayer, _ := k.GetForwardRelayerAddress(ctx, packetId)

k.DeleteForwardRelayerAddress(ctx, packetId)
relayer, found := k.GetRelayerAddressForAsyncAck(ctx, packetId)
if !found {
return sdkerrors.Wrapf(types.ErrRelayerNotFoundForAsyncAck, "no relayer address stored for async acknowledgement for packet with portID: %s, channelID: %s, sequence: %d", packetId.PortId, packetId.ChannelId, packetId.Sequence)
}

// it is possible that a relayer has not registered a counterparty address.
// if there is no registered counterparty address then write acknowledgement with empty relayer address and refund recv_fee.
forwardRelayer, _ := k.GetCounterpartyAddress(ctx, relayer, packet.GetSourceChannel())

ack := types.NewIncentivizedAcknowledgement(relayer, acknowledgement.Acknowledgement(), acknowledgement.Success())
ack := types.NewIncentivizedAcknowledgement(forwardRelayer, acknowledgement.Acknowledgement(), acknowledgement.Success())

k.DeleteForwardRelayerAddress(ctx, packetId)

// ics4Wrapper may be core IBC or higher-level middleware
return k.ics4Wrapper.WriteAcknowledgement(ctx, chanCap, packet, ack)
Expand Down
15 changes: 8 additions & 7 deletions modules/apps/29-fee/keeper/relay_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,17 @@ func (suite *KeeperTestSuite) TestWriteAcknowledgementAsync() {
}{
{
"success",
func() {},
true,
},
{
"forward relayer address is successfully deleted",
func() {
suite.chainB.GetSimApp().IBCFeeKeeper.SetForwardRelayerAddress(suite.chainB.GetContext(), channeltypes.NewPacketId(suite.path.EndpointA.ChannelID, suite.path.EndpointA.ChannelConfig.PortID, 1), suite.chainA.SenderAccount.GetAddress().String())
suite.chainB.GetSimApp().IBCFeeKeeper.SetRelayerAddressForAsyncAck(suite.chainB.GetContext(), channeltypes.NewPacketId(suite.path.EndpointA.ChannelID, suite.path.EndpointA.ChannelConfig.PortID, 1), suite.chainA.SenderAccount.GetAddress().String())
suite.chainB.GetSimApp().IBCFeeKeeper.SetCounterpartyAddress(suite.chainB.GetContext(), suite.chainA.SenderAccount.GetAddress().String(), suite.chainB.SenderAccount.GetAddress().String(), suite.path.EndpointA.ChannelID)
},
true,
},
{
"relayer address not set for async WriteAcknowledgement",
func() {},
false,
},
}

for _, tc := range testCases {
Expand Down Expand Up @@ -56,7 +57,7 @@ func (suite *KeeperTestSuite) TestWriteAcknowledgementAsync() {

if tc.expPass {
suite.Require().NoError(err)
_, found := suite.chainB.GetSimApp().IBCFeeKeeper.GetForwardRelayerAddress(suite.chainB.GetContext(), channeltypes.NewPacketId(suite.path.EndpointA.ChannelID, suite.path.EndpointA.ChannelConfig.PortID, 1))
_, found := suite.chainB.GetSimApp().IBCFeeKeeper.GetRelayerAddressForAsyncAck(suite.chainB.GetContext(), channeltypes.NewPacketId(suite.path.EndpointA.ChannelID, suite.path.EndpointA.ChannelConfig.PortID, 1))
suite.Require().False(found)
} else {
suite.Require().Error(err)
Expand Down
1 change: 1 addition & 0 deletions modules/apps/29-fee/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,5 @@ var (
ErrCounterpartyAddressEmpty = sdkerrors.Register(ModuleName, 7, "counterparty address must not be empty")
ErrForwardRelayerAddressNotFound = sdkerrors.Register(ModuleName, 8, "forward relayer address not found")
ErrFeeNotEnabled = sdkerrors.Register(ModuleName, 9, "fee module is not enabled for this channel. If this error occurs after channel setup, fee module may not be enabled")
ErrRelayerNotFoundForAsyncAck = sdkerrors.Register(ModuleName, 10, "relayer address must be stored for async WriteAcknowledgement")
)
22 changes: 18 additions & 4 deletions modules/apps/29-fee/types/genesis.go
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
package types

import (
"strings"

sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"

host "github.com/cosmos/ibc-go/v3/modules/core/24-host"
)

// NewGenesisState creates a 29-fee GenesisState instance.
func NewGenesisState(identifiedFees []IdentifiedPacketFee, feeEnabledChannels []FeeEnabledChannel, registeredRelayers []RegisteredRelayerAddress) *GenesisState {
func NewGenesisState(identifiedFees []IdentifiedPacketFee, feeEnabledChannels []FeeEnabledChannel, registeredRelayers []RegisteredRelayerAddress, forwardRelayers []ForwardRelayerAddress) *GenesisState {
return &GenesisState{
IdentifiedFees: identifiedFees,
FeeEnabledChannels: feeEnabledChannels,
RegisteredRelayers: registeredRelayers,
ForwardRelayers: forwardRelayers,
}
}

// DefaultGenesisState returns a GenesisState with "transfer" as the default PortID.
func DefaultGenesisState() *GenesisState {
return &GenesisState{
IdentifiedFees: []IdentifiedPacketFee{},
ForwardRelayers: []ForwardRelayerAddress{},
FeeEnabledChannels: []FeeEnabledChannel{},
RegisteredRelayers: []RegisteredRelayerAddress{},
}
Expand Down Expand Up @@ -48,15 +52,25 @@ func (gs GenesisState) Validate() error {

// Validate RegisteredRelayers
for _, rel := range gs.RegisteredRelayers {
_, err := sdk.AccAddressFromBech32(rel.Address)
if err != nil {
if _, err := sdk.AccAddressFromBech32(rel.Address); err != nil {
return sdkerrors.Wrap(err, "failed to convert source relayer address into sdk.AccAddress")
}

if rel.CounterpartyAddress == "" {
if strings.TrimSpace(rel.CounterpartyAddress) == "" {
return ErrCounterpartyAddressEmpty
}
}

// Validate ForwardRelayers
for _, rel := range gs.ForwardRelayers {
if _, err := sdk.AccAddressFromBech32(rel.Address); err != nil {
return sdkerrors.Wrap(err, "failed to convert forward relayer address into sdk.AccAddress")
}

if err := rel.PacketId.Validate(); err != nil {
return err
}
}

return nil
}
42 changes: 33 additions & 9 deletions modules/apps/29-fee/types/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@ var (

func TestValidateGenesis(t *testing.T) {
var (
packetId channeltypes.PacketId
fee types.Fee
refundAcc string
sender string
counterparty string
portID string
channelID string
seq uint64
packetId channeltypes.PacketId
fee types.Fee
refundAcc string
sender string
forwardAddr string
counterparty string
portID string
channelID string
packetChannelID string
seq uint64
)

testCases := []struct {
Expand Down Expand Up @@ -118,7 +120,21 @@ func TestValidateGenesis(t *testing.T) {
{
"invalid RegisteredRelayers: invalid counterparty",
func() {
counterparty = ""
counterparty = " "
},
false,
},
{
"invalid ForwardRelayerAddress: invalid forwardAddr",
func() {
forwardAddr = ""
},
false,
},
{
"invalid ForwardRelayerAddress: invalid packet",
func() {
packetChannelID = "1"
},
false,
},
Expand All @@ -127,6 +143,7 @@ func TestValidateGenesis(t *testing.T) {
for _, tc := range testCases {
portID = transfertypes.PortID
channelID = ibctesting.FirstChannelID
packetChannelID = ibctesting.FirstChannelID
seq = uint64(1)

// build PacketId & Fee
Expand All @@ -146,6 +163,7 @@ func TestValidateGenesis(t *testing.T) {
// relayer addresses
sender = addr1
counterparty = addr2
forwardAddr = addr2

tc.malleate()

Expand All @@ -170,6 +188,12 @@ func TestValidateGenesis(t *testing.T) {
CounterpartyAddress: counterparty,
},
},
ForwardRelayers: []types.ForwardRelayerAddress{
{
Address: forwardAddr,
PacketId: channeltypes.NewPacketId(packetChannelID, portID, 1),
},
},
}

err := genState.Validate()
Expand Down