diff --git a/modules/apps/27-interchain-accounts/controller/ibc_middleware_test.go b/modules/apps/27-interchain-accounts/controller/ibc_middleware_test.go index daa037e447e..7f727a3a594 100644 --- a/modules/apps/27-interchain-accounts/controller/ibc_middleware_test.go +++ b/modules/apps/27-interchain-accounts/controller/ibc_middleware_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/suite" "github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/controller" + controllerkeeper "github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/controller/keeper" "github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/controller/types" icatypes "github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/types" fee "github.com/cosmos/ibc-go/v6/modules/apps/29-fee" @@ -840,3 +841,80 @@ func (suite *InterchainAccountsTestSuite) TestGetAppVersion() { suite.Require().True(found) suite.Require().Equal(path.EndpointA.ChannelConfig.Version, appVersion) } + +func (suite *InterchainAccountsTestSuite) TestInFlightHandshakeRespectsGoAPICaller() { + path := NewICAPath(suite.chainA, suite.chainB) + suite.coordinator.SetupConnections(path) + + // initiate a channel handshake such that channel.State == INIT + err := RegisterInterchainAccount(path.EndpointA, suite.chainA.SenderAccount.GetAddress().String()) + suite.Require().NoError(err) + + // attempt to start a second handshake via the controller msg server + msgServer := controllerkeeper.NewMsgServerImpl(&suite.chainA.GetSimApp().ICAControllerKeeper) + msgRegisterInterchainAccount := types.NewMsgRegisterInterchainAccount(path.EndpointA.ConnectionID, suite.chainA.SenderAccount.GetAddress().String(), TestVersion) + + res, err := msgServer.RegisterInterchainAccount(suite.chainA.GetContext(), msgRegisterInterchainAccount) + suite.Require().Error(err) + suite.Require().Nil(res) +} + +func (suite *InterchainAccountsTestSuite) TestInFlightHandshakeRespectsMsgServerCaller() { + path := NewICAPath(suite.chainA, suite.chainB) + suite.coordinator.SetupConnections(path) + + // initiate a channel handshake such that channel.State == INIT + msgServer := controllerkeeper.NewMsgServerImpl(&suite.chainA.GetSimApp().ICAControllerKeeper) + msgRegisterInterchainAccount := types.NewMsgRegisterInterchainAccount(path.EndpointA.ConnectionID, suite.chainA.SenderAccount.GetAddress().String(), TestVersion) + + res, err := msgServer.RegisterInterchainAccount(suite.chainA.GetContext(), msgRegisterInterchainAccount) + suite.Require().NotNil(res) + suite.Require().NoError(err) + + // attempt to start a second handshake via the legacy Go API + err = RegisterInterchainAccount(path.EndpointA, suite.chainA.SenderAccount.GetAddress().String()) + suite.Require().Error(err) +} + +func (suite *InterchainAccountsTestSuite) TestClosedChannelReopensWithMsgServer() { + path := NewICAPath(suite.chainA, suite.chainB) + suite.coordinator.SetupConnections(path) + + err := SetupICAPath(path, suite.chainA.SenderAccount.GetAddress().String()) + suite.Require().NoError(err) + + // set the channel state to closed + err = path.EndpointA.SetChannelClosed() + suite.Require().NoError(err) + err = path.EndpointB.SetChannelClosed() + suite.Require().NoError(err) + + // reset endpoint channel ids + path.EndpointA.ChannelID = "" + path.EndpointB.ChannelID = "" + + // fetch the next channel sequence before reinitiating the channel handshake + channelSeq := suite.chainA.GetSimApp().GetIBCKeeper().ChannelKeeper.GetNextChannelSequence(suite.chainA.GetContext()) + + // route a new MsgRegisterInterchainAccount in order to reopen the + msgServer := controllerkeeper.NewMsgServerImpl(&suite.chainA.GetSimApp().ICAControllerKeeper) + msgRegisterInterchainAccount := types.NewMsgRegisterInterchainAccount(path.EndpointA.ConnectionID, suite.chainA.SenderAccount.GetAddress().String(), path.EndpointA.ChannelConfig.Version) + + res, err := msgServer.RegisterInterchainAccount(suite.chainA.GetContext(), msgRegisterInterchainAccount) + suite.Require().NoError(err) + suite.Require().Equal(channeltypes.FormatChannelIdentifier(channelSeq), res.ChannelId) + + // assign the channel sequence to endpointA before generating proofs and initiating the TRY step + path.EndpointA.ChannelID = channeltypes.FormatChannelIdentifier(channelSeq) + + path.EndpointA.Chain.NextBlock() + + err = path.EndpointB.ChanOpenTry() + suite.Require().NoError(err) + + err = path.EndpointA.ChanOpenAck() + suite.Require().NoError(err) + + err = path.EndpointB.ChanOpenConfirm() + suite.Require().NoError(err) +} diff --git a/modules/apps/27-interchain-accounts/controller/keeper/account.go b/modules/apps/27-interchain-accounts/controller/keeper/account.go index 661348dcc30..42c8b8da522 100644 --- a/modules/apps/27-interchain-accounts/controller/keeper/account.go +++ b/modules/apps/27-interchain-accounts/controller/keeper/account.go @@ -32,6 +32,10 @@ func (k Keeper) RegisterInterchainAccount(ctx sdk.Context, connectionID, owner, return err } + if k.IsMiddlewareDisabled(ctx, portID, connectionID) && !k.IsActiveChannelClosed(ctx, connectionID, portID) { + return sdkerrors.Wrap(icatypes.ErrInvalidChannelFlow, "channel is already active or a handshake is in flight") + } + k.SetMiddlewareEnabled(ctx, portID, connectionID) _, err = k.registerInterchainAccount(ctx, connectionID, portID, version) diff --git a/modules/apps/27-interchain-accounts/controller/keeper/genesis.go b/modules/apps/27-interchain-accounts/controller/keeper/genesis.go index 83f7f4f79be..e6c4270bf18 100644 --- a/modules/apps/27-interchain-accounts/controller/keeper/genesis.go +++ b/modules/apps/27-interchain-accounts/controller/keeper/genesis.go @@ -25,6 +25,8 @@ func InitGenesis(ctx sdk.Context, keeper Keeper, state genesistypes.ControllerGe if ch.IsMiddlewareEnabled { keeper.SetMiddlewareEnabled(ctx, ch.PortId, ch.ConnectionId) + } else { + keeper.SetMiddlewareDisabled(ctx, ch.PortId, ch.ConnectionId) } } diff --git a/modules/apps/27-interchain-accounts/controller/keeper/genesis_test.go b/modules/apps/27-interchain-accounts/controller/keeper/genesis_test.go index 98b8fbdc962..ffbec8ed824 100644 --- a/modules/apps/27-interchain-accounts/controller/keeper/genesis_test.go +++ b/modules/apps/27-interchain-accounts/controller/keeper/genesis_test.go @@ -20,6 +20,12 @@ func (suite *KeeperTestSuite) TestInitGenesis() { ChannelId: ibctesting.FirstChannelID, IsMiddlewareEnabled: true, }, + { + ConnectionId: "connection-1", + PortId: "test-port-1", + ChannelId: "channel-1", + IsMiddlewareEnabled: false, + }, }, InterchainAccounts: []genesistypes.RegisteredInterchainAccount{ { @@ -40,6 +46,9 @@ func (suite *KeeperTestSuite) TestInitGenesis() { isMiddlewareEnabled := suite.chainA.GetSimApp().ICAControllerKeeper.IsMiddlewareEnabled(suite.chainA.GetContext(), TestPortID, ibctesting.FirstConnectionID) suite.Require().True(isMiddlewareEnabled) + isMiddlewareDisabled := suite.chainA.GetSimApp().ICAControllerKeeper.IsMiddlewareDisabled(suite.chainA.GetContext(), "test-port-1", "connection-1") + suite.Require().True(isMiddlewareDisabled) + accountAdrr, found := suite.chainA.GetSimApp().ICAControllerKeeper.GetInterchainAccountAddress(suite.chainA.GetContext(), ibctesting.FirstConnectionID, TestPortID) suite.Require().True(found) suite.Require().Equal(interchainAccAddr.String(), accountAdrr) diff --git a/modules/apps/27-interchain-accounts/controller/keeper/keeper.go b/modules/apps/27-interchain-accounts/controller/keeper/keeper.go index addce165e47..261b2bef87a 100644 --- a/modules/apps/27-interchain-accounts/controller/keeper/keeper.go +++ b/modules/apps/27-interchain-accounts/controller/keeper/keeper.go @@ -1,6 +1,7 @@ package keeper import ( + "bytes" "fmt" "strings" @@ -146,6 +147,17 @@ func (k Keeper) GetOpenActiveChannel(ctx sdk.Context, connectionID, portID strin return "", false } +// IsActiveChannelClosed retrieves the active channel from the store and returns true if the channel state is CLOSED, otherwise false +func (k Keeper) IsActiveChannelClosed(ctx sdk.Context, connectionID, portID string) bool { + channelID, found := k.GetActiveChannelID(ctx, connectionID, portID) + if !found { + return false + } + + channel, found := k.channelKeeper.GetChannel(ctx, portID, channelID) + return found && channel.State == channeltypes.CLOSED +} + // GetAllActiveChannels returns a list of all active interchain accounts controller channels and their associated connection and port identifiers func (k Keeper) GetAllActiveChannels(ctx sdk.Context) []genesistypes.ActiveChannel { store := ctx.KVStore(k.storeKey) @@ -227,13 +239,25 @@ func (k Keeper) SetInterchainAccountAddress(ctx sdk.Context, connectionID, portI // IsMiddlewareEnabled returns true if the underlying application callbacks are enabled for given port and connection identifier pair, otherwise false func (k Keeper) IsMiddlewareEnabled(ctx sdk.Context, portID, connectionID string) bool { store := ctx.KVStore(k.storeKey) - return store.Has(icatypes.KeyIsMiddlewareEnabled(portID, connectionID)) + return bytes.Equal(icatypes.MiddlewareEnabled, store.Get(icatypes.KeyIsMiddlewareEnabled(portID, connectionID))) +} + +// IsMiddlewareDisabled returns true if the underlying application callbacks are disabled for the given port and connection identifier pair, otherwise false +func (k Keeper) IsMiddlewareDisabled(ctx sdk.Context, portID, connectionID string) bool { + store := ctx.KVStore(k.storeKey) + return bytes.Equal(icatypes.MiddlewareDisabled, store.Get(icatypes.KeyIsMiddlewareEnabled(portID, connectionID))) } // SetMiddlewareEnabled stores a flag to indicate that the underlying application callbacks should be enabled for the given port and connection identifier pair func (k Keeper) SetMiddlewareEnabled(ctx sdk.Context, portID, connectionID string) { store := ctx.KVStore(k.storeKey) - store.Set(icatypes.KeyIsMiddlewareEnabled(portID, connectionID), []byte{byte(1)}) + store.Set(icatypes.KeyIsMiddlewareEnabled(portID, connectionID), icatypes.MiddlewareEnabled) +} + +// SetMiddlewareDisabled stores a flag to indicate that the underlying application callbacks should be disabled for the given port and connection identifier pair +func (k Keeper) SetMiddlewareDisabled(ctx sdk.Context, portID, connectionID string) { + store := ctx.KVStore(k.storeKey) + store.Set(icatypes.KeyIsMiddlewareEnabled(portID, connectionID), icatypes.MiddlewareDisabled) } // DeleteMiddlewareEnabled deletes the middleware enabled flag stored in state diff --git a/modules/apps/27-interchain-accounts/controller/keeper/msg_server.go b/modules/apps/27-interchain-accounts/controller/keeper/msg_server.go index 2eacba9fc26..56cd8fae372 100644 --- a/modules/apps/27-interchain-accounts/controller/keeper/msg_server.go +++ b/modules/apps/27-interchain-accounts/controller/keeper/msg_server.go @@ -4,6 +4,7 @@ import ( "context" sdk "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" "github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/controller/types" icatypes "github.com/cosmos/ibc-go/v6/modules/apps/27-interchain-accounts/types" @@ -30,6 +31,12 @@ func (s msgServer) RegisterInterchainAccount(goCtx context.Context, msg *types.M return nil, err } + if s.IsMiddlewareEnabled(ctx, portID, msg.ConnectionId) && !s.IsActiveChannelClosed(ctx, msg.ConnectionId, portID) { + return nil, sdkerrors.Wrap(icatypes.ErrInvalidChannelFlow, "channel is already active or a handshake is in flight") + } + + s.SetMiddlewareDisabled(ctx, portID, msg.ConnectionId) + channelID, err := s.registerInterchainAccount(ctx, msg.ConnectionId, portID, msg.Version) if err != nil { return nil, err diff --git a/modules/apps/27-interchain-accounts/types/keys.go b/modules/apps/27-interchain-accounts/types/keys.go index a7118df685d..2a062130ce2 100644 --- a/modules/apps/27-interchain-accounts/types/keys.go +++ b/modules/apps/27-interchain-accounts/types/keys.go @@ -42,6 +42,12 @@ var ( // IsMiddlewareEnabledPrefix defines the key prefix used to store a flag for legacy API callback routing via ibc middleware IsMiddlewareEnabledPrefix = "isMiddlewareEnabled" + + // MiddlewareEnabled is the value used to signal that controller middleware is enabled + MiddlewareEnabled = []byte{0x01} + + // MiddlewareDisabled is the value used to signal that controller midleware is disabled + MiddlewareDisabled = []byte{0x02} ) // KeyActiveChannel creates and returns a new key used for active channels store operations