diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e8b308835fe..5489e3bc0566 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -62,6 +62,7 @@ Every module contains its own CHANGELOG.md. Please refer to the module you are i * (crypto/keyring) [#20212](https://github.com/cosmos/cosmos-sdk/pull/20212) Expose the db keyring used in the keystore. * (genutil) [#19971](https://github.com/cosmos/cosmos-sdk/pull/19971) Allow manually setting the consensus key type in genesis * (debug) [#20328](https://github.com/cosmos/cosmos-sdk/pull/20328) Add consensus address for debug cmd. +* (client) [#20356](https://github.com/cosmos/cosmos-sdk/pull/20356) Overwrite client context instead of setting new one ### Improvements diff --git a/client/cmd.go b/client/cmd.go index e817649d24dc..d4f0693d7e4a 100644 --- a/client/cmd.go +++ b/client/cmd.go @@ -359,14 +359,17 @@ func GetClientContextFromCmd(cmd *cobra.Command) Context { // SetCmdClientContext sets a command's Context value to the provided argument. // If the context has not been set, set the given context as the default. func SetCmdClientContext(cmd *cobra.Command, clientCtx Context) error { - var cmdCtx context.Context - - if cmd.Context() == nil { + cmdCtx := cmd.Context() + if cmdCtx == nil { cmdCtx = context.Background() + } + + v := cmd.Context().Value(ClientContextKey) + if clientCtxPtr, ok := v.(*Context); ok { + *clientCtxPtr = clientCtx } else { - cmdCtx = cmd.Context() + cmd.SetContext(context.WithValue(cmdCtx, ClientContextKey, &clientCtx)) } - cmd.SetContext(context.WithValue(cmdCtx, ClientContextKey, &clientCtx)) return nil } diff --git a/client/cmd_test.go b/client/cmd_test.go index 81d5719ccfb6..559d31d39f40 100644 --- a/client/cmd_test.go +++ b/client/cmd_test.go @@ -79,11 +79,13 @@ func TestSetCmdClientContextHandler(t *testing.T) { name string expectedContext client.Context args []string + ctx context.Context }{ { "no flags set", initClientCtx, []string{}, + context.WithValue(context.Background(), client.ClientContextKey, &client.Context{}), }, { "flags set", @@ -91,6 +93,7 @@ func TestSetCmdClientContextHandler(t *testing.T) { []string{ fmt.Sprintf("--%s=new-chain-id", flags.FlagChainID), }, + context.WithValue(context.Background(), client.ClientContextKey, &client.Context{}), }, { "flags set with space", @@ -99,6 +102,25 @@ func TestSetCmdClientContextHandler(t *testing.T) { fmt.Sprintf("--%s", flags.FlagHome), "/tmp/dir", }, + context.Background(), + }, + { + "no context provided", + initClientCtx.WithHomeDir("/tmp/noctx"), + []string{ + fmt.Sprintf("--%s", flags.FlagHome), + "/tmp/noctx", + }, + nil, + }, + { + "with invalid client value in the context", + initClientCtx.WithHomeDir("/tmp/invalid"), + []string{ + fmt.Sprintf("--%s", flags.FlagHome), + "/tmp/invalid", + }, + context.WithValue(context.Background(), client.ClientContextKey, "invalid"), }, } @@ -106,13 +128,11 @@ func TestSetCmdClientContextHandler(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { - ctx := context.WithValue(context.Background(), client.ClientContextKey, &client.Context{}) - cmd := newCmd() _ = testutil.ApplyMockIODiscardOutErr(cmd) cmd.SetArgs(tc.args) - require.NoError(t, cmd.ExecuteContext(ctx)) + require.NoError(t, cmd.ExecuteContext(tc.ctx)) clientCtx := client.GetClientContextFromCmd(cmd) require.Equal(t, tc.expectedContext, clientCtx)