diff --git a/modules/core/03-connection/keeper/handshake_test.go b/modules/core/03-connection/keeper/handshake_test.go index 6109e6c40f1..880ae768ab0 100644 --- a/modules/core/03-connection/keeper/handshake_test.go +++ b/modules/core/03-connection/keeper/handshake_test.go @@ -120,9 +120,7 @@ func (suite *KeeperTestSuite) TestConnOpenTry() { delayPeriod = uint64(time.Hour.Nanoseconds()) // set delay period on counterparty to non-zero value - conn := path.EndpointA.GetConnection() - conn.DelayPeriod = delayPeriod - suite.chainA.App.GetIBCKeeper().ConnectionKeeper.SetConnection(suite.chainA.GetContext(), path.EndpointA.ConnectionID, conn) + path.EndpointA.UpdateConnection(func(connection *types.ConnectionEnd) { connection.DelayPeriod = delayPeriod }) // commit in order for proof to return correct value suite.coordinator.CommitBlock(suite.chainA) @@ -341,12 +339,7 @@ func (suite *KeeperTestSuite) TestConnOpenAck() { suite.Require().NoError(err) // modify connB to set counterparty connection identifier to wrong identifier - connection, found := suite.chainA.App.GetIBCKeeper().ConnectionKeeper.GetConnection(suite.chainA.GetContext(), path.EndpointA.ConnectionID) - suite.Require().True(found) - - connection.Counterparty.ConnectionId = "badconnectionid" - - suite.chainA.App.GetIBCKeeper().ConnectionKeeper.SetConnection(suite.chainA.GetContext(), path.EndpointA.ConnectionID, connection) + path.EndpointA.UpdateConnection(func(c *types.ConnectionEnd) { c.Counterparty.ConnectionId = "badconnectionid" }) err = path.EndpointA.UpdateClient() suite.Require().NoError(err) diff --git a/modules/core/03-connection/keeper/verify_test.go b/modules/core/03-connection/keeper/verify_test.go index 0d49dd5de0d..95b0127db6f 100644 --- a/modules/core/03-connection/keeper/verify_test.go +++ b/modules/core/03-connection/keeper/verify_test.go @@ -30,9 +30,7 @@ func (suite *KeeperTestSuite) TestVerifyClientState() { }{ {"verification success", func() {}, true}, {"client state not found", func() { - connection := path.EndpointA.GetConnection() - connection.ClientId = ibctesting.InvalidID - path.EndpointA.SetConnection(connection) + path.EndpointA.UpdateConnection(func(c *types.ConnectionEnd) { c.ClientId = ibctesting.InvalidID }) }, false}, {"consensus state for proof height not found", func() { heightDiff = 5 @@ -95,9 +93,7 @@ func (suite *KeeperTestSuite) TestVerifyClientConsensusState() { }{ {"verification success", func() {}, true}, {"client state not found", func() { - connection := path.EndpointA.GetConnection() - connection.ClientId = ibctesting.InvalidID - path.EndpointA.SetConnection(connection) + path.EndpointA.UpdateConnection(func(c *types.ConnectionEnd) { c.ClientId = ibctesting.InvalidID }) }, false}, {"consensus state not found", func() { heightDiff = 5 @@ -169,17 +165,13 @@ func (suite *KeeperTestSuite) TestVerifyConnectionState() { }{ {"verification success", func() {}, true}, {"client state not found - changed client ID", func() { - connection := path.EndpointA.GetConnection() - connection.ClientId = ibctesting.InvalidID - path.EndpointA.SetConnection(connection) + path.EndpointA.UpdateConnection(func(c *types.ConnectionEnd) { c.ClientId = ibctesting.InvalidID }) }, false}, {"consensus state not found - increased proof height", func() { heightDiff = 5 }, false}, {"verification failed - connection state is different than proof", func() { - connection := path.EndpointA.GetConnection() - connection.State = types.TRYOPEN - path.EndpointA.SetConnection(connection) + path.EndpointA.UpdateConnection(func(c *types.ConnectionEnd) { c.State = types.TRYOPEN }) }, false}, {"client status is not active - client is expired", func() { clientState := path.EndpointA.GetClientState().(*ibctm.ClientState) @@ -234,9 +226,7 @@ func (suite *KeeperTestSuite) TestVerifyChannelState() { }{ {"verification success", func() {}, true}, {"client state not found- changed client ID", func() { - connection := path.EndpointA.GetConnection() - connection.ClientId = ibctesting.InvalidID - path.EndpointA.SetConnection(connection) + path.EndpointA.UpdateConnection(func(c *types.ConnectionEnd) { c.ClientId = ibctesting.InvalidID }) }, false}, {"consensus state not found - increased proof height", func() { heightDiff = 5 @@ -314,9 +304,7 @@ func (suite *KeeperTestSuite) TestVerifyPacketCommitment() { timePerBlock = 1 }, false}, {"client state not found- changed client ID", func() { - connection := path.EndpointB.GetConnection() - connection.ClientId = ibctesting.InvalidID - path.EndpointB.SetConnection(connection) + path.EndpointB.UpdateConnection(func(c *types.ConnectionEnd) { c.ClientId = ibctesting.InvalidID }) }, false}, {"consensus state not found - increased proof height", func() { heightDiff = 5 @@ -406,9 +394,7 @@ func (suite *KeeperTestSuite) TestVerifyPacketAcknowledgement() { timePerBlock = 1 }, false}, {"client state not found- changed client ID", func() { - connection := path.EndpointA.GetConnection() - connection.ClientId = ibctesting.InvalidID - path.EndpointA.SetConnection(connection) + path.EndpointA.UpdateConnection(func(c *types.ConnectionEnd) { c.ClientId = ibctesting.InvalidID }) }, false}, {"consensus state not found - increased proof height", func() { heightDiff = 5 @@ -507,9 +493,7 @@ func (suite *KeeperTestSuite) TestVerifyPacketReceiptAbsence() { timePerBlock = 1 }, false}, {"client state not found - changed client ID", func() { - connection := path.EndpointA.GetConnection() - connection.ClientId = ibctesting.InvalidID - path.EndpointA.SetConnection(connection) + path.EndpointA.UpdateConnection(func(c *types.ConnectionEnd) { c.ClientId = ibctesting.InvalidID }) }, false}, {"consensus state not found - increased proof height", func() { heightDiff = 5 @@ -613,9 +597,7 @@ func (suite *KeeperTestSuite) TestVerifyNextSequenceRecv() { timePerBlock = 1 }, false}, {"client state not found- changed client ID", func() { - connection := path.EndpointA.GetConnection() - connection.ClientId = ibctesting.InvalidID - path.EndpointA.SetConnection(connection) + path.EndpointA.UpdateConnection(func(c *types.ConnectionEnd) { c.ClientId = ibctesting.InvalidID }) }, false}, {"consensus state not found - increased proof height", func() { heightDiff = 5 @@ -709,9 +691,7 @@ func (suite *KeeperTestSuite) TestVerifyUpgradeErrorReceipt() { { name: "fails with bad client id", malleate: func() { - connection := path.EndpointB.GetConnection() - connection.ClientId = ibctesting.InvalidID - path.EndpointB.SetConnection(connection) + path.EndpointB.UpdateConnection(func(c *types.ConnectionEnd) { c.ClientId = ibctesting.InvalidID }) }, expPass: false, }, @@ -792,9 +772,7 @@ func (suite *KeeperTestSuite) TestVerifyUpgrade() { { name: "fails with bad client id", malleate: func() { - connection := path.EndpointB.GetConnection() - connection.ClientId = ibctesting.InvalidID - path.EndpointB.SetConnection(connection) + path.EndpointB.UpdateConnection(func(c *types.ConnectionEnd) { c.ClientId = ibctesting.InvalidID }) }, expPass: false, }, diff --git a/modules/core/04-channel/keeper/handshake_test.go b/modules/core/04-channel/keeper/handshake_test.go index 8b1366b0539..f99db4eac47 100644 --- a/modules/core/04-channel/keeper/handshake_test.go +++ b/modules/core/04-channel/keeper/handshake_test.go @@ -55,15 +55,10 @@ func (suite *KeeperTestSuite) TestChanOpenInit() { path.SetupConnections() // modify connA versions - conn := path.EndpointA.GetConnection() + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { + c.Versions = append(c.Versions, connectiontypes.NewVersion("2", []string{"ORDER_ORDERED", "ORDER_UNORDERED"})) + }) - version := connectiontypes.NewVersion("2", []string{"ORDER_ORDERED", "ORDER_UNORDERED"}) - conn.Versions = append(conn.Versions, version) - - suite.chainA.App.GetIBCKeeper().ConnectionKeeper.SetConnection( - suite.chainA.GetContext(), - path.EndpointA.ConnectionID, conn, - ) features = []string{"ORDER_ORDERED", "ORDER_UNORDERED"} suite.chainA.CreatePortCapability(suite.chainA.GetSimApp().ScopedIBCMockKeeper, ibctesting.MockPort) portCap = suite.chainA.GetPortCapability(ibctesting.MockPort) @@ -72,15 +67,10 @@ func (suite *KeeperTestSuite) TestChanOpenInit() { path.SetupConnections() // modify connA versions to only support UNORDERED channels - conn := path.EndpointA.GetConnection() + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { + c.Versions = []*connectiontypes.Version{connectiontypes.NewVersion("1", []string{"ORDER_UNORDERED"})} + }) - version := connectiontypes.NewVersion("1", []string{"ORDER_UNORDERED"}) - conn.Versions = []*connectiontypes.Version{version} - - suite.chainA.App.GetIBCKeeper().ConnectionKeeper.SetConnection( - suite.chainA.GetContext(), - path.EndpointA.ConnectionID, conn, - ) // NOTE: Opening UNORDERED channels is still expected to pass but ORDERED channels should fail features = []string{"ORDER_UNORDERED"} suite.chainA.CreatePortCapability(suite.chainA.GetSimApp().ScopedIBCMockKeeper, ibctesting.MockPort) @@ -225,15 +215,10 @@ func (suite *KeeperTestSuite) TestChanOpenTry() { suite.Require().NoError(err) // modify connB versions - conn := path.EndpointB.GetConnection() + path.EndpointB.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { + c.Versions = append(c.Versions, connectiontypes.NewVersion("2", []string{"ORDER_ORDERED", "ORDER_UNORDERED"})) + }) - version := connectiontypes.NewVersion("2", []string{"ORDER_ORDERED", "ORDER_UNORDERED"}) - conn.Versions = append(conn.Versions, version) - - suite.chainB.App.GetIBCKeeper().ConnectionKeeper.SetConnection( - suite.chainB.GetContext(), - path.EndpointB.ConnectionID, conn, - ) suite.chainB.CreatePortCapability(suite.chainB.GetSimApp().ScopedIBCMockKeeper, ibctesting.MockPort) portCap = suite.chainB.GetPortCapability(ibctesting.MockPort) }, false}, @@ -244,15 +229,10 @@ func (suite *KeeperTestSuite) TestChanOpenTry() { suite.Require().NoError(err) // modify connA versions to only support UNORDERED channels - conn := path.EndpointA.GetConnection() + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { + c.Versions = []*connectiontypes.Version{connectiontypes.NewVersion("1", []string{"ORDER_UNORDERED"})} + }) - version := connectiontypes.NewVersion("1", []string{"ORDER_UNORDERED"}) - conn.Versions = []*connectiontypes.Version{version} - - suite.chainA.App.GetIBCKeeper().ConnectionKeeper.SetConnection( - suite.chainA.GetContext(), - path.EndpointA.ConnectionID, conn, - ) suite.chainA.CreatePortCapability(suite.chainA.GetSimApp().ScopedIBCMockKeeper, ibctesting.MockPort) portCap = suite.chainA.GetPortCapability(ibctesting.MockPort) }, false}, diff --git a/modules/core/04-channel/keeper/packet_test.go b/modules/core/04-channel/keeper/packet_test.go index 12308037356..6e9f471d3c4 100644 --- a/modules/core/04-channel/keeper/packet_test.go +++ b/modules/core/04-channel/keeper/packet_test.go @@ -63,9 +63,7 @@ func (suite *KeeperTestSuite) TestSendPacket() { solomachine := ibctesting.NewSolomachine(suite.T(), suite.chainA.Codec, "solomachinesingle", "testing", 1) path.EndpointA.ClientID = clienttypes.FormatClientIdentifier(exported.Solomachine, 10) path.EndpointA.SetClientState(solomachine.ClientState()) - connection := path.EndpointA.GetConnection() - connection.ClientId = path.EndpointA.ClientID - path.EndpointA.SetConnection(connection) + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.ClientId = path.EndpointA.ClientID }) channelCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) }, true}, @@ -78,9 +76,8 @@ func (suite *KeeperTestSuite) TestSendPacket() { solomachine := ibctesting.NewSolomachine(suite.T(), suite.chainA.Codec, "solomachinesingle", "testing", 1) path.EndpointA.ClientID = clienttypes.FormatClientIdentifier(exported.Solomachine, 10) path.EndpointA.SetClientState(solomachine.ClientState()) - connection := path.EndpointA.GetConnection() - connection.ClientId = path.EndpointA.ClientID - path.EndpointA.SetConnection(connection) + + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.ClientId = path.EndpointA.ClientID }) channelCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) }, true}, @@ -134,9 +131,7 @@ func (suite *KeeperTestSuite) TestSendPacket() { sourceChannel = path.EndpointA.ChannelID // change connection client ID - connection := path.EndpointA.GetConnection() - connection.ClientId = ibctesting.InvalidID - suite.chainA.App.GetIBCKeeper().ConnectionKeeper.SetConnection(suite.chainA.GetContext(), path.EndpointA.ConnectionID, connection) + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.ClientId = ibctesting.InvalidID }) channelCap = suite.chainA.GetChannelCapability(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID) }, false}, @@ -185,11 +180,12 @@ func (suite *KeeperTestSuite) TestSendPacket() { solomachine := ibctesting.NewSolomachine(suite.T(), suite.chainA.Codec, "solomachinesingle", "testing", 1) path.EndpointA.ClientID = clienttypes.FormatClientIdentifier(exported.Solomachine, 10) path.EndpointA.SetClientState(solomachine.ClientState()) - connection := path.EndpointA.GetConnection() - connection.ClientId = path.EndpointA.ClientID - path.EndpointA.SetConnection(connection) + + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.ClientId = path.EndpointA.ClientID }) clientState := path.EndpointA.GetClientState() + connection := path.EndpointA.GetConnection() + timestamp, err := suite.chainA.App.GetIBCKeeper().ConnectionKeeper.GetTimestampAtHeight(suite.chainA.GetContext(), connection, clientState.GetLatestHeight()) suite.Require().NoError(err) diff --git a/modules/core/04-channel/keeper/upgrade_test.go b/modules/core/04-channel/keeper/upgrade_test.go index d970cfa38f3..c5c91aaff68 100644 --- a/modules/core/04-channel/keeper/upgrade_test.go +++ b/modules/core/04-channel/keeper/upgrade_test.go @@ -77,10 +77,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeInit() { { "invalid proposed channel connection state", func() { - connectionEnd := path.EndpointA.GetConnection() - connectionEnd.State = connectiontypes.UNINITIALIZED - - suite.chainA.GetSimApp().GetIBCKeeper().ConnectionKeeper.SetConnection(suite.chainA.GetContext(), "connection-100", connectionEnd) + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.UNINITIALIZED }) upgradeFields.ConnectionHops = []string{"connection-100"} }, false, @@ -183,9 +180,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeTry() { { "invalid connection state", func() { - connectionEnd := path.EndpointB.GetConnection() - connectionEnd.State = connectiontypes.UNINITIALIZED - suite.chainB.GetSimApp().GetIBCKeeper().ConnectionKeeper.SetConnection(suite.chainB.GetContext(), path.EndpointB.ConnectionID, connectionEnd) + path.EndpointB.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.UNINITIALIZED }) }, connectiontypes.ErrInvalidConnectionState, }, @@ -706,9 +701,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() { { "invalid connection state", func() { - connectionEnd := path.EndpointA.GetConnection() - connectionEnd.State = connectiontypes.UNINITIALIZED - path.EndpointA.SetConnection(connectionEnd) + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.UNINITIALIZED }) }, connectiontypes.ErrInvalidConnectionState, }, @@ -1074,9 +1067,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeConfirm() { { "invalid connection state", func() { - connectionEnd := path.EndpointB.GetConnection() - connectionEnd.State = connectiontypes.UNINITIALIZED - path.EndpointB.SetConnection(connectionEnd) + path.EndpointB.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.UNINITIALIZED }) }, connectiontypes.ErrInvalidConnectionState, }, @@ -1340,9 +1331,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeOpen() { { "invalid connection state", func() { - connectionEnd := path.EndpointA.GetConnection() - connectionEnd.State = connectiontypes.UNINITIALIZED - path.EndpointA.SetConnection(connectionEnd) + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.UNINITIALIZED }) }, connectiontypes.ErrInvalidConnectionState, }, @@ -2125,9 +2114,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { { "connection not open", func() { - connectionEnd := path.EndpointA.GetConnection() - connectionEnd.State = connectiontypes.UNINITIALIZED - path.EndpointA.SetConnection(connectionEnd) + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.UNINITIALIZED }) }, connectiontypes.ErrInvalidConnectionState, }, @@ -2316,9 +2303,7 @@ func (suite *KeeperTestSuite) TestStartFlush() { { "connection state is not in OPEN state", func() { - conn := path.EndpointB.GetConnection() - conn.State = connectiontypes.INIT - path.EndpointB.SetConnection(conn) + path.EndpointB.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.INIT }) }, connectiontypes.ErrInvalidConnectionState, }, @@ -2423,9 +2408,7 @@ func (suite *KeeperTestSuite) TestValidateUpgradeFields() { { name: "fails when connection is not open", malleate: func() { - connection := path.EndpointA.GetConnection() - connection.State = connectiontypes.UNINITIALIZED - path.EndpointA.SetConnection(connection) + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { c.State = connectiontypes.UNINITIALIZED }) }, expPass: false, }, @@ -2435,9 +2418,9 @@ func (suite *KeeperTestSuite) TestValidateUpgradeFields() { // update channel version first so that existing channel end is not identical to proposed upgrade proposedUpgrade.Version = mock.UpgradeVersion - connection := path.EndpointA.GetConnection() - connection.Versions = []*connectiontypes.Version{} - path.EndpointA.SetConnection(connection) + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { + c.Versions = []*connectiontypes.Version{} + }) }, expPass: false, }, @@ -2447,11 +2430,9 @@ func (suite *KeeperTestSuite) TestValidateUpgradeFields() { // update channel version first so that existing channel end is not identical to proposed upgrade proposedUpgrade.Version = mock.UpgradeVersion - connection := path.EndpointA.GetConnection() - connection.Versions = []*connectiontypes.Version{ - connectiontypes.NewVersion("1", []string{"ORDER_ORDERED"}), - } - path.EndpointA.SetConnection(connection) + path.EndpointA.UpdateConnection(func(c *connectiontypes.ConnectionEnd) { + c.Versions = []*connectiontypes.Version{connectiontypes.NewVersion("1", []string{"ORDER_ORDERED"})} + }) }, expPass: false, }, diff --git a/testing/endpoint.go b/testing/endpoint.go index 45a7d57c788..655f51b1b8d 100644 --- a/testing/endpoint.go +++ b/testing/endpoint.go @@ -892,3 +892,12 @@ func (endpoint *Endpoint) GetProposedUpgrade() channeltypes.Upgrade { return upgrade } + +// UpdateConnection updates the connection associated with the given endpoint. It accepts a +// closure which takes a connection allowing the caller to modify the connection fields. +func (endpoint *Endpoint) UpdateConnection(updater func(connection *connectiontypes.ConnectionEnd)) { + connection := endpoint.GetConnection() + updater(&connection) + + endpoint.SetConnection(connection) +}