Skip to content

Commit

Permalink
Refactor export function (#64)
Browse files Browse the repository at this point in the history
* refactor export function

* add method PrepForZeroHeightGenesis to htlc and random

Co-authored-by: 王迪 <wangdi@bianjie.ai>
  • Loading branch information
dgsbl and 王迪 committed Jan 10, 2021
1 parent 4593431 commit d03a90d
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 16 deletions.
23 changes: 16 additions & 7 deletions modules/htlc/genesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package htlc
import (
"encoding/hex"
"fmt"

tmbytes "github.com/tendermint/tendermint/libs/bytes"

sdk "github.com/cosmos/cosmos-sdk/types"
Expand Down Expand Up @@ -32,18 +31,28 @@ func ExportGenesis(ctx sdk.Context, k keeper.Keeper) *types.GenesisState {

k.IterateHTLCs(ctx, func(hlock tmbytes.HexBytes, h types.HTLC) (stop bool) {
if h.State == types.Open {
h.ExpirationHeight = h.ExpirationHeight - uint64(ctx.BlockHeight()) + 1
pendingHtlcs[hlock.String()] = h
} else if h.State == types.Expired {
if err := k.RefundHTLC(ctx, hlock); err != nil {
panic(fmt.Errorf("failed to export the HTLC genesis state: %s", hlock.String()))
}
}

return false
})

return &types.GenesisState{
PendingHtlcs: pendingHtlcs,
}
}

func PrepForZeroHeightGenesis(ctx sdk.Context, k keeper.Keeper) {
k.IterateHTLCs(
ctx,
func(hlock tmbytes.HexBytes, h types.HTLC) (stop bool) {
if h.State == types.Open {
h.ExpirationHeight = h.ExpirationHeight - uint64(ctx.BlockHeight()) + 1
k.SetHTLC(ctx,h,hlock)
} else if h.State == types.Expired {
if err := k.RefundHTLC(ctx, hlock); err != nil {
panic(fmt.Errorf("failed to export the HTLC genesis state: %s", hlock.String()))
}
}
return false
})
}
13 changes: 11 additions & 2 deletions modules/random/genesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ func InitGenesis(ctx sdk.Context, k keeper.Keeper, data types.GenesisState) {
func ExportGenesis(ctx sdk.Context, k keeper.Keeper) *types.GenesisState {
pendingRequests := make(map[string]types.Requests)

k.IterateRandomRequestQueue(ctx, func(height int64, request types.Request) bool {
leftHeight := fmt.Sprintf("%d", height-ctx.BlockHeight()+1)
k.IterateRandomRequestQueue(ctx, func(height int64, reqID []byte, request types.Request) bool {
leftHeight := fmt.Sprintf("%d", height)
heightRequests, ok := pendingRequests[leftHeight]
if ok {
heightRequests.Requests = append(heightRequests.Requests, request)
Expand All @@ -43,3 +43,12 @@ func ExportGenesis(ctx sdk.Context, k keeper.Keeper) *types.GenesisState {

return &types.GenesisState{PendingRandomRequests: pendingRequests}
}

func PrepForZeroHeightGenesis(ctx sdk.Context, k keeper.Keeper) {
k.IterateRandomRequestQueue(ctx, func(height int64, reqID []byte, request types.Request) bool {
leftHeight := height-ctx.BlockHeight()+1
k.DequeueRandomRequest(ctx, height, reqID)
k.EnqueueRandomRequest(ctx, leftHeight, reqID, request)
return false
})
}
5 changes: 2 additions & 3 deletions modules/random/genesis_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func (suite *GenesisTestSuite) TestExportGenesis() {

// get the pending requests from queue
storedRequests := make(map[int64][]types.Request)
suite.keeper.IterateRandomRequestQueue(suite.ctx, func(h int64, r types.Request) bool {
suite.keeper.IterateRandomRequestQueue(suite.ctx, func(h int64, reqID []byte, r types.Request) bool {
storedRequests[h] = append(storedRequests[h], r)
return false
})
Expand All @@ -76,7 +76,6 @@ func (suite *GenesisTestSuite) TestExportGenesis() {
// assert that exported requests are consistent with requests in queue
for height, requests := range exportedRequests {
h, _ := strconv.ParseInt(height, 10, 64)
storedHeight := h + testNewHeight - 1
suite.Equal(storedRequests[storedHeight], requests.Requests)
suite.Equal(storedRequests[h], requests.Requests)
}
}
2 changes: 1 addition & 1 deletion modules/random/keeper/grpc_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ func (suite *KeeperTestSuite) TestGRPCRandomRequestQueue() {
suite.Require().NoError(err)
var requests = make([]types.Request, 0)

app.RandomKeeper.IterateRandomRequestQueue(ctx, func(h int64, r types.Request) (stop bool) {
app.RandomKeeper.IterateRandomRequestQueue(ctx, func(h int64, reqID []byte, r types.Request) (stop bool) {
requests = append(requests, r)
return false
})
Expand Down
5 changes: 3 additions & 2 deletions modules/random/keeper/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ func (k Keeper) IterateRandomRequestQueueByHeight(ctx sdk.Context, height int64)
}

// IterateRandomRequestQueue iterates through the random number request queue
func (k Keeper) IterateRandomRequestQueue(ctx sdk.Context, op func(h int64, r types.Request) (stop bool)) {
func (k Keeper) IterateRandomRequestQueue(ctx sdk.Context, op func(h int64, reqID []byte, r types.Request) (stop bool)) {
store := ctx.KVStore(k.storeKey)

iterator := sdk.KVStorePrefixIterator(store, types.PrefixRandomRequestQueue)
Expand All @@ -178,11 +178,12 @@ func (k Keeper) IterateRandomRequestQueue(ctx sdk.Context, op func(h int64, r ty
for ; iterator.Valid(); iterator.Next() {
keyParts := bytes.Split(iterator.Key(), types.KeyDelimiter)
height, _ := strconv.ParseInt(string(keyParts[1]), 10, 64)
reqID := keyParts[2]

var request types.Request
k.cdc.MustUnmarshalBinaryBare(iterator.Value(), &request)

if stop := op(height, request); stop {
if stop := op(height, reqID, request); stop {
break
}
}
Expand Down
2 changes: 1 addition & 1 deletion modules/random/keeper/querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func queryRandomRequestQueueByHeight(ctx sdk.Context, height int64, k Keeper) []
func queryAllRandomRequestsInQueue(ctx sdk.Context, k Keeper) []types.Request {
var requests = make([]types.Request, 0)

k.IterateRandomRequestQueue(ctx, func(h int64, r types.Request) (stop bool) {
k.IterateRandomRequestQueue(ctx, func(h int64, reqID []byte, r types.Request) (stop bool) {
requests = append(requests, r)
return false
})
Expand Down

0 comments on commit d03a90d

Please sign in to comment.