diff --git a/modules/light-clients/06-solomachine/types/update.go b/modules/light-clients/06-solomachine/types/update.go index fd775544b61..eb3941b89d2 100644 --- a/modules/light-clients/06-solomachine/types/update.go +++ b/modules/light-clients/06-solomachine/types/update.go @@ -23,16 +23,7 @@ func (cs ClientState) CheckHeaderAndUpdateState( return nil, nil, err } - // TODO: Remove this type assertion, replace with misbehaviour checking and update state - smHeader, ok := msg.(*Header) - if !ok { - return nil, nil, sdkerrors.Wrapf( - clienttypes.ErrInvalidHeader, "expected %T, got %T", &Header{}, msg, - ) - } - - clientState, consensusState := update(&cs, smHeader) - return clientState, consensusState, nil + return cs.UpdateState(ctx, cdc, clientStore, msg) } // VerifyClientMessage introspects the provided ClientMessage and checks its validity @@ -105,16 +96,16 @@ func (cs ClientState) verifyMisbehaviour(ctx sdk.Context, cdc codec.BinaryCodec, return nil } -// update the consensus state to the new public key and an incremented sequence -func update(clientState *ClientState, header *Header) (*ClientState, *ConsensusState) { +// UpdateState updates the consensus state to the new public key and an incremented sequence. +func (cs ClientState) UpdateState(ctx sdk.Context, cdc codec.BinaryCodec, clientStore sdk.KVStore, clientMsg exported.ClientMessage) (exported.ClientState, exported.ConsensusState, error) { + smHeader := clientMsg.(*Header) consensusState := &ConsensusState{ - PublicKey: header.NewPublicKey, - Diversifier: header.NewDiversifier, - Timestamp: header.Timestamp, + PublicKey: smHeader.NewPublicKey, + Diversifier: smHeader.NewDiversifier, + Timestamp: smHeader.Timestamp, } - // increment sequence number - clientState.Sequence++ - clientState.ConsensusState = consensusState - return clientState, consensusState + cs.Sequence++ + cs.ConsensusState = consensusState + return &cs, consensusState, nil } diff --git a/modules/light-clients/06-solomachine/types/update_test.go b/modules/light-clients/06-solomachine/types/update_test.go index 336b9926be0..ddd239ab9a4 100644 --- a/modules/light-clients/06-solomachine/types/update_test.go +++ b/modules/light-clients/06-solomachine/types/update_test.go @@ -328,7 +328,7 @@ func (suite *SoloMachineTestSuite) TestVerifyClientMessageHeader() { // setup test tc.setup() - err := clientState.VerifyClientMessage(suite.chainA.GetContext(), suite.chainA.Codec, nil, clientMsg) + err := clientState.VerifyClientMessage(suite.chainA.GetContext(), suite.chainA.Codec, suite.store, clientMsg) if tc.expPass { suite.Require().NoError(err) @@ -560,7 +560,7 @@ func (suite *SoloMachineTestSuite) TestVerifyClientMessageMisbehaviour() { // setup test tc.setup() - err := clientState.VerifyClientMessage(suite.chainA.GetContext(), suite.chainA.Codec, nil, clientMsg) + err := clientState.VerifyClientMessage(suite.chainA.GetContext(), suite.chainA.Codec, suite.store, clientMsg) if tc.expPass { suite.Require().NoError(err) @@ -571,3 +571,56 @@ func (suite *SoloMachineTestSuite) TestVerifyClientMessageMisbehaviour() { } } } + +func (suite *SoloMachineTestSuite) TestUpdateState() { + var ( + clientState exported.ClientState + clientMsg exported.ClientMessage + ) + + // test singlesig and multisig public keys + for _, solomachine := range []*ibctesting.Solomachine{suite.solomachine, suite.solomachineMulti} { + + testCases := []struct { + name string + setup func() + expPass bool + }{ + { + "successful update", + func() { + clientState = solomachine.ClientState() + clientMsg = solomachine.CreateHeader() + }, + true, + }, + } + + for _, tc := range testCases { + tc := tc + + suite.Run(tc.name, func() { + // setup test + tc.setup() + + clientState, ok := clientState.(*types.ClientState) + if ok { + cs, consensusState, err := clientState.UpdateState(suite.chainA.GetContext(), suite.chainA.Codec, suite.store, clientMsg) + + if tc.expPass { + suite.Require().NoError(err) + suite.Require().Equal(clientMsg.(*types.Header).NewPublicKey, cs.(*types.ClientState).ConsensusState.PublicKey) + suite.Require().Equal(false, cs.(*types.ClientState).IsFrozen) + suite.Require().Equal(clientMsg.(*types.Header).Sequence+1, cs.(*types.ClientState).Sequence) + suite.Require().Equal(consensusState, cs.(*types.ClientState).ConsensusState) + } else { + suite.Require().Error(err) + suite.Require().Nil(clientState) + suite.Require().Nil(consensusState) + } + } + + }) + } + } +}