diff --git a/x/wasm/internal/keeper/keeper.go b/x/wasm/internal/keeper/keeper.go index 2721b262a1..ac441b844e 100644 --- a/x/wasm/internal/keeper/keeper.go +++ b/x/wasm/internal/keeper/keeper.go @@ -120,7 +120,7 @@ func NewKeeper( authZPolicy: DefaultAuthorizationPolicy{}, paramSpace: paramSpace, } - keeper.queryPlugins = DefaultQueryPlugins(bankKeeper, stakingKeeper, distKeeper, queryRouter, &keeper).Merge(customPlugins) + keeper.queryPlugins = DefaultQueryPlugins(bankKeeper, stakingKeeper, distKeeper, channelKeeper, queryRouter, &keeper).Merge(customPlugins) for _, o := range opts { o.apply(&keeper) } @@ -268,10 +268,7 @@ func (k Keeper) instantiate(ctx sdk.Context, codeID uint64, creator, admin sdk.A prefixStore := prefix.NewStore(ctx.KVStore(k.storeKey), prefixStoreKey) // prepare querier - querier := QueryHandler{ - Ctx: ctx, - Plugins: k.queryPlugins, - } + querier := NewQueryHandler(ctx, k.queryPlugins, contractAddress) // instantiate wasm contract gas := gasForContract(ctx) @@ -340,10 +337,7 @@ func (k Keeper) Execute(ctx sdk.Context, contractAddress sdk.AccAddress, caller info := types.NewInfo(caller, coins) // prepare querier - querier := QueryHandler{ - Ctx: ctx, - Plugins: k.queryPlugins, - } + querier := NewQueryHandler(ctx, k.queryPlugins, contractAddress) gas := gasForContract(ctx) res, gasUsed, execErr := k.wasmer.Execute(codeInfo.CodeHash, env, info, msg, prefixStore, cosmwasmAPI, querier, gasMeter(ctx), gas) consumeGas(ctx, gasUsed) @@ -405,10 +399,7 @@ func (k Keeper) migrate(ctx sdk.Context, contractAddress sdk.AccAddress, caller env := types.NewEnv(ctx, contractAddress) // prepare querier - querier := QueryHandler{ - Ctx: ctx, - Plugins: k.queryPlugins, - } + querier := NewQueryHandler(ctx, k.queryPlugins, contractAddress) prefixStoreKey := types.GetContractStorePrefix(contractAddress) prefixStore := prefix.NewStore(ctx.KVStore(k.storeKey), prefixStoreKey) @@ -454,10 +445,7 @@ func (k Keeper) Sudo(ctx sdk.Context, contractAddress sdk.AccAddress, msg []byte env := types.NewEnv(ctx, contractAddress) // prepare querier - querier := QueryHandler{ - Ctx: ctx, - Plugins: k.queryPlugins, - } + querier := NewQueryHandler(ctx, k.queryPlugins, contractAddress) gas := gasForContract(ctx) res, gasUsed, execErr := k.wasmer.Sudo(codeInfo.CodeHash, env, msg, prefixStore, cosmwasmAPI, querier, gasMeter(ctx), gas) consumeGas(ctx, gasUsed) @@ -543,10 +531,7 @@ func (k Keeper) QuerySmart(ctx sdk.Context, contractAddr sdk.AccAddress, req []b return nil, err } // prepare querier - querier := QueryHandler{ - Ctx: ctx, - Plugins: k.queryPlugins, - } + querier := NewQueryHandler(ctx, k.queryPlugins, contractAddr) env := types.NewEnv(ctx, contractAddr) queryResult, gasUsed, qErr := k.wasmer.Query(codeInfo.CodeHash, env, req, prefixStore, cosmwasmAPI, querier, gasMeter(ctx), gasForContract(ctx)) diff --git a/x/wasm/internal/keeper/query_plugins.go b/x/wasm/internal/keeper/query_plugins.go index 91d03ba0b5..2d006f8b54 100644 --- a/x/wasm/internal/keeper/query_plugins.go +++ b/x/wasm/internal/keeper/query_plugins.go @@ -3,6 +3,7 @@ package keeper import ( "encoding/json" "fmt" + "github.com/CosmWasm/wasmd/x/wasm/internal/types" wasmvmtypes "github.com/CosmWasm/wasmvm/types" sdk "github.com/cosmos/cosmos-sdk/types" @@ -18,6 +19,15 @@ import ( type QueryHandler struct { Ctx sdk.Context Plugins QueryPlugins + Caller sdk.AccAddress +} + +func NewQueryHandler(ctx sdk.Context, plugins QueryPlugins, caller sdk.AccAddress) QueryHandler { + return QueryHandler{ + Ctx: ctx, + Plugins: plugins, + Caller: caller, + } } // -- interfaces from baseapp - so we can use the GPRQueryRouter -- @@ -51,6 +61,9 @@ func (q QueryHandler) Query(request wasmvmtypes.QueryRequest, gasLimit uint64) ( if request.Custom != nil { return q.Plugins.Custom(subctx, request.Custom) } + if request.IBC != nil { + return q.Plugins.IBC(subctx, q.Caller, request.IBC) + } if request.Staking != nil { return q.Plugins.Staking(subctx, request.Staking) } @@ -72,15 +85,17 @@ type CustomQuerier func(ctx sdk.Context, request json.RawMessage) ([]byte, error type QueryPlugins struct { Bank func(ctx sdk.Context, request *wasmvmtypes.BankQuery) ([]byte, error) Custom CustomQuerier + IBC func(ctx sdk.Context, caller sdk.AccAddress, request *wasmvmtypes.IBCQuery) ([]byte, error) Staking func(ctx sdk.Context, request *wasmvmtypes.StakingQuery) ([]byte, error) Stargate func(ctx sdk.Context, request *wasmvmtypes.StargateQuery) ([]byte, error) Wasm func(ctx sdk.Context, request *wasmvmtypes.WasmQuery) ([]byte, error) } -func DefaultQueryPlugins(bank bankkeeper.ViewKeeper, staking stakingkeeper.Keeper, distKeeper distributionkeeper.Keeper, queryRouter GRPCQueryRouter, wasm *Keeper) QueryPlugins { +func DefaultQueryPlugins(bank bankkeeper.ViewKeeper, staking stakingkeeper.Keeper, distKeeper distributionkeeper.Keeper, channelKeeper types.ChannelKeeper, queryRouter GRPCQueryRouter, wasm *Keeper) QueryPlugins { return QueryPlugins{ Bank: BankQuerier(bank), Custom: NoCustomQuerier, + IBC: IBCQuerier(wasm, channelKeeper), Staking: StakingQuerier(staking, distKeeper), Stargate: StargateQuerier(queryRouter), Wasm: WasmQuerier(wasm), @@ -98,6 +113,9 @@ func (e QueryPlugins) Merge(o *QueryPlugins) QueryPlugins { if o.Custom != nil { e.Custom = o.Custom } + if o.IBC != nil { + e.IBC = o.IBC + } if o.Staking != nil { e.Staking = o.Staking } @@ -146,6 +164,67 @@ func NoCustomQuerier(sdk.Context, json.RawMessage) ([]byte, error) { return nil, wasmvmtypes.UnsupportedRequest{Kind: "custom"} } +func IBCQuerier(wasm *Keeper, channelKeeper types.ChannelKeeper) func(ctx sdk.Context, caller sdk.AccAddress, request *wasmvmtypes.IBCQuery) ([]byte, error) { + return func(ctx sdk.Context, caller sdk.AccAddress, request *wasmvmtypes.IBCQuery) ([]byte, error) { + if request.PortID != nil { + contractInfo := wasm.GetContractInfo(ctx, caller) + res := wasmvmtypes.PortIDResponse{ + PortID: contractInfo.IBCPortID, + } + return json.Marshal(res) + } + if request.ListChannels != nil { + portID := request.ListChannels.PortID + var channels wasmvmtypes.IBCEndpoints + channelKeeper.IterateChannels(ctx, func(ch types.IdentifiedChannel) bool { + if portID == "" || portID == ch.PortId { + newChan := wasmvmtypes.IBCEndpoint{ + PortID: ch.PortId, + ChannelID: ch.ChannelId, + } + channels = append(channels, newChan) + } + return false + }) + res := wasmvmtypes.ListChannelsResponse{ + Channels: channels, + } + return json.Marshal(res) + } + if request.Channel != nil { + channelID := request.Channel.ChannelID + portID := request.Channel.PortID + if portID == "" { + contractInfo := wasm.GetContractInfo(ctx, caller) + portID = contractInfo.IBCPortID + } + got, found := channelKeeper.GetChannel(ctx, portID, channelID) + var channel *wasmvmtypes.IBCChannel + if found { + channel = &wasmvmtypes.IBCChannel{ + Endpoint: wasmvmtypes.IBCEndpoint{ + PortID: portID, + ChannelID: channelID, + }, + CounterpartyEndpoint: wasmvmtypes.IBCEndpoint{ + PortID: got.Counterparty.PortId, + ChannelID: got.Counterparty.ChannelId, + }, + Order: got.Ordering.String(), + Version: got.Version, + CounterpartyVersion: "", + ConnectionID: got.ConnectionHops[0], + } + } + res := wasmvmtypes.ChannelResponse{ + Channel: channel, + } + return json.Marshal(res) + } + return nil, wasmvmtypes.UnsupportedRequest{Kind: "unknown IBCQuery variant"} + } +} + func StargateQuerier(queryRouter GRPCQueryRouter) func(ctx sdk.Context, request *wasmvmtypes.StargateQuery) ([]byte, error) { return func(ctx sdk.Context, msg *wasmvmtypes.StargateQuery) ([]byte, error) { route := queryRouter.Route(msg.Path) diff --git a/x/wasm/internal/keeper/relay.go b/x/wasm/internal/keeper/relay.go index ac75381276..3db1194f49 100644 --- a/x/wasm/internal/keeper/relay.go +++ b/x/wasm/internal/keeper/relay.go @@ -23,10 +23,7 @@ func (k Keeper) OnOpenChannel( } env := types.NewEnv(ctx, contractAddr) - querier := QueryHandler{ - Ctx: ctx, - Plugins: k.queryPlugins, - } + querier := NewQueryHandler(ctx, k.queryPlugins, contractAddr) gas := gasForContract(ctx) gasUsed, execErr := k.wasmer.IBCChannelOpen(codeInfo.CodeHash, env, channel, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) @@ -56,10 +53,7 @@ func (k Keeper) OnConnectChannel( } env := types.NewEnv(ctx, contractAddr) - querier := QueryHandler{ - Ctx: ctx, - Plugins: k.queryPlugins, - } + querier := NewQueryHandler(ctx, k.queryPlugins, contractAddr) gas := gasForContract(ctx) res, gasUsed, execErr := k.wasmer.IBCChannelConnect(codeInfo.CodeHash, env, channel, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) @@ -95,10 +89,7 @@ func (k Keeper) OnCloseChannel( } params := types.NewEnv(ctx, contractAddr) - querier := QueryHandler{ - Ctx: ctx, - Plugins: k.queryPlugins, - } + querier := NewQueryHandler(ctx, k.queryPlugins, contractAddr) gas := gasForContract(ctx) res, gasUsed, execErr := k.wasmer.IBCChannelClose(codeInfo.CodeHash, params, channel, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) @@ -134,10 +125,7 @@ func (k Keeper) OnRecvPacket( } env := types.NewEnv(ctx, contractAddr) - querier := QueryHandler{ - Ctx: ctx, - Plugins: k.queryPlugins, - } + querier := NewQueryHandler(ctx, k.queryPlugins, contractAddr) gas := gasForContract(ctx) res, gasUsed, execErr := k.wasmer.IBCPacketReceive(codeInfo.CodeHash, env, packet, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) @@ -174,10 +162,7 @@ func (k Keeper) OnAckPacket( } env := types.NewEnv(ctx, contractAddr) - querier := QueryHandler{ - Ctx: ctx, - Plugins: k.queryPlugins, - } + querier := NewQueryHandler(ctx, k.queryPlugins, contractAddr) gas := gasForContract(ctx) res, gasUsed, execErr := k.wasmer.IBCPacketAck(codeInfo.CodeHash, env, acknowledgement, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) @@ -210,10 +195,7 @@ func (k Keeper) OnTimeoutPacket( } env := types.NewEnv(ctx, contractAddr) - querier := QueryHandler{ - Ctx: ctx, - Plugins: k.queryPlugins, - } + querier := NewQueryHandler(ctx, k.queryPlugins, contractAddr) gas := gasForContract(ctx) res, gasUsed, execErr := k.wasmer.IBCPacketTimeout(codeInfo.CodeHash, env, packet, prefixStore, cosmwasmAPI, querier, ctx.GasMeter(), gas) diff --git a/x/wasm/internal/keeper/wasmtesting/mock_keepers.go b/x/wasm/internal/keeper/wasmtesting/mock_keepers.go index d382a321cd..acee3a4082 100644 --- a/x/wasm/internal/keeper/wasmtesting/mock_keepers.go +++ b/x/wasm/internal/keeper/wasmtesting/mock_keepers.go @@ -12,6 +12,7 @@ type MockChannelKeeper struct { GetNextSequenceSendFn func(ctx sdk.Context, portID, channelID string) (uint64, bool) SendPacketFn func(ctx sdk.Context, channelCap *capabilitytypes.Capability, packet ibcexported.PacketI) error ChanCloseInitFn func(ctx sdk.Context, portID, channelID string, chanCap *capabilitytypes.Capability) error + GetAllChannelsFn func(ctx sdk.Context) []channeltypes.IdentifiedChannel } func (m *MockChannelKeeper) GetChannel(ctx sdk.Context, srcPort, srcChan string) (channel channeltypes.Channel, found bool) { @@ -21,6 +22,24 @@ func (m *MockChannelKeeper) GetChannel(ctx sdk.Context, srcPort, srcChan string) return m.GetChannelFn(ctx, srcPort, srcChan) } +func (m *MockChannelKeeper) GetAllChannels(ctx sdk.Context) []channeltypes.IdentifiedChannel { + if m.GetAllChannelsFn == nil { + panic("not supposed to be called!") + } + return m.GetAllChannelsFn(ctx) +} + +// Auto-implemented from GetAllChannels data +func (m *MockChannelKeeper) IterateChannels(ctx sdk.Context, cb func(channeltypes.IdentifiedChannel) bool) { + channels := m.GetAllChannels(ctx) + for _, channel := range channels { + stop := cb(channel) + if stop { + break + } + } +} + func (m *MockChannelKeeper) GetNextSequenceSend(ctx sdk.Context, portID, channelID string) (uint64, bool) { if m.GetNextSequenceSendFn == nil { panic("not supposed to be called!") diff --git a/x/wasm/internal/types/ibc.go b/x/wasm/internal/types/ibc.go index b042254bdb..e4f9d664aa 100644 --- a/x/wasm/internal/types/ibc.go +++ b/x/wasm/internal/types/ibc.go @@ -16,8 +16,12 @@ type ChannelKeeper interface { GetNextSequenceSend(ctx sdk.Context, portID, channelID string) (uint64, bool) SendPacket(ctx sdk.Context, channelCap *capabilitytypes.Capability, packet ibcexported.PacketI) error ChanCloseInit(ctx sdk.Context, portID, channelID string, chanCap *capabilitytypes.Capability) error + GetAllChannels(ctx sdk.Context) (channels []channeltypes.IdentifiedChannel) + IterateChannels(ctx sdk.Context, cb func(channeltypes.IdentifiedChannel) bool) } +type IdentifiedChannel = channeltypes.IdentifiedChannel + // ClientKeeper defines the expected IBC client keeper type ClientKeeper interface { GetClientConsensusState(ctx sdk.Context, clientID string) (connection ibcexported.ConsensusState, found bool)