diff --git a/.github/workflows/e2e-test.yml b/.github/workflows/tests.yml similarity index 81% rename from .github/workflows/e2e-test.yml rename to .github/workflows/tests.yml index b54f6057f..8800e6dd0 100644 --- a/.github/workflows/e2e-test.yml +++ b/.github/workflows/tests.yml @@ -1,4 +1,4 @@ -name: End to End Test +name: Tests / Code Coverage on: push: @@ -9,7 +9,7 @@ on: pull_request: jobs: - end-to-end-test: + tests: strategy: matrix: go-version: [1.20.x] @@ -44,32 +44,22 @@ jobs: ${{ runner.os }}-go- - name: Setup GitHub Token run: git config --global url.https://$GH_ACCESS_TOKEN@github.com/.insteadOf https://github.com/ - - name: Build + - name: run tests run: | - make build - - name: start e2e local chain - run: | - make e2e_start_localchain - sleep 5 - - name: run e2e test - run: | - make e2e_test - - name: stop e2e local chain - run: | - make e2e_stop_localchain + make test - name: make coverage report id: coverage-report if: github.event_name == 'pull_request' continue-on-error: true run: | - echo '## E2E Test Coverage Report' >> coverage-report.txt + echo '## Test Coverage Report' >> coverage-report.txt echo 'commit-id: ${{ github.event.pull_request.head.sha }}' >> coverage-report.txt echo 'generated by https://github.com/vladopajic/go-test-coverage' >> coverage-report.txt echo >> coverage-report.txt echo '
Additional details and impacted files' >> coverage-report.txt echo >> coverage-report.txt echo '```' >> coverage-report.txt - make check-e2e-coverage >> coverage-report.txt + make check-coverage >> coverage-report.txt echo '```' >> coverage-report.txt echo >> coverage-report.txt echo '
' >> coverage-report.txt diff --git a/.testcoverage.yml b/.testcoverage.yml index 8415c3bba..052ad1dc8 100644 --- a/.testcoverage.yml +++ b/.testcoverage.yml @@ -34,9 +34,15 @@ exclude: - \.pb\.gw\.go$ # excludes all protobuf generated files - .*_mocks.go$ # excludes all protobuf generated files - testutil/.* - - e2e/.* + - e2e/.* + - types/.* + - sdk/.* - internal/sequence/.* - + - x/types/.* + - .*/simulation/.* + - .*/module.go + - .*/module_simulation.go + # NOTES: # - symbol `/` in all path regexps will be replaced by -# current OS file path separator to properly work on Windows \ No newline at end of file +# current OS file path separator to properly work on Windows diff --git a/Makefile b/Makefile index f015a6109..f9653fbf1 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ .PHONY: build build-linux build-macos build-windows -.PHONY: tools proto-gen proto-format test e2e_test ci lint +.PHONY: tools proto-gen proto-format test e2e_init_localchain e2e_test ci lint .PHONY: install-go-test-coverage check-coverage VERSION=$(shell git describe --tags --always) @@ -48,22 +48,23 @@ docker-image: go mod vendor # temporary, should be removed after open source docker build . -t ${IMAGE_NAME} -test: +unit_test: go test -failfast $$(go list ./... | grep -v e2e | grep -v sdk) -e2e_start_localchain: - bash ./deployment/localup/localup.sh all 1 7 +e2e_init_localchain: build + bash ./deployment/localup/localup.sh init 1 7 + bash ./deployment/localup/localup.sh generate 1 7 -e2e_stop_localchain: - bash ./deployment/localup/localup.sh stop +e2e_test: e2e_init_localchain + go test -p 1 -failfast -v ./e2e/... -timeout 99999s -e2e_test: - go test -p 1 -failfast -v ./e2e/... -coverpkg=./... -covermode=atomic -coverprofile=./coverage.out -timeout 99999s +test: e2e_init_localchain + go test -p 1 -failfast -v $$(go list ./... | grep -v sdk) -coverpkg=./... -covermode=atomic -coverprofile=./coverage.out -timeout 99999s install-go-test-coverage: @go install github.com/vladopajic/go-test-coverage/v2@latest -check-e2e-coverage: install-go-test-coverage +check-coverage: install-go-test-coverage @go-test-coverage --config=./.testcoverage.yml || true lint: @@ -73,5 +74,5 @@ lint: proto-gen-check: proto-gen git diff --exit-code -ci: proto-format-check build test e2e_start_localchain e2e_test lint proto-gen-check +ci: proto-format-check build test e2e_init_localchain e2e_test lint proto-gen-check echo "ci passed" diff --git a/e2e/core/basesuite.go b/e2e/core/basesuite.go index e2197e6e6..1bfcfd88e 100644 --- a/e2e/core/basesuite.go +++ b/e2e/core/basesuite.go @@ -9,10 +9,14 @@ import ( "math" "strconv" "strings" + "sync" "time" sdkmath "cosmossdk.io/math" "github.com/cometbft/cometbft/crypto/tmhash" + tmlog "github.com/cometbft/cometbft/libs/log" + sdkClient "github.com/cosmos/cosmos-sdk/client" + sdkServer "github.com/cosmos/cosmos-sdk/server" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/query" "github.com/cosmos/cosmos-sdk/types/tx" @@ -22,8 +26,10 @@ import ( gov "github.com/cosmos/cosmos-sdk/x/gov/types" govtypesv1 "github.com/cosmos/cosmos-sdk/x/gov/types/v1" "github.com/prysmaticlabs/prysm/crypto/bls" + "github.com/spf13/cobra" "github.com/stretchr/testify/suite" + "github.com/bnb-chain/greenfield/cmd/gnfd/cmd" "github.com/bnb-chain/greenfield/sdk/client" "github.com/bnb-chain/greenfield/sdk/keys" "github.com/bnb-chain/greenfield/sdk/types" @@ -43,6 +49,8 @@ type StorageProvider struct { GlobalVirtualGroupFamilies map[uint32][]*virtualgroupmoduletypes.GlobalVirtualGroup } +var initValidatorOnce sync.Once + type BaseSuite struct { suite.Suite Config *Config @@ -55,8 +63,72 @@ type BaseSuite struct { StorageProviders map[uint32]*StorageProvider } +func findCommand(cmd *cobra.Command, name string) *cobra.Command { + if len(cmd.Commands()) == 0 { + return nil + } + for _, subCmd := range cmd.Commands() { + if subCmd.Name() == name { + return subCmd + } + if found := findCommand(subCmd, name); found != nil { + return found + } + } + + return nil +} + +func (s *BaseSuite) InitChain() { + s.T().Log("Initializing chain") + rootCmd, _ := cmd.NewRootCmd() + // Initialize and start chain + ctx := context.Background() + srvCtx := sdkServer.NewDefaultContext() + ctx = context.WithValue(ctx, sdkClient.ClientContextKey, &sdkClient.Context{}) + ctx = context.WithValue(ctx, sdkServer.ServerContextKey, srvCtx) + + // if you want to debug with chain logs, please discard this + startCmd := findCommand(rootCmd, "start") + startCmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { + err := rootCmd.PersistentPreRunE(cmd, args) + if err != nil { + return err + } + ctx := cmd.Context() + serverCtx := sdkServer.GetServerContextFromCmd(cmd) + serverCtx.Logger = tmlog.NewNopLogger() + ctx = context.WithValue(ctx, sdkServer.ServerContextKey, serverCtx) + cmd.SetContext(ctx) + return nil + } + rootCmd.SetArgs([]string{ + "start", + "--home", s.Config.ValidatorHomeDir, + "--rpc.laddr", s.Config.ValidatorTmRPCAddr, + }) + + errChan := make(chan error) + go func() { + errChan <- rootCmd.ExecuteContext(ctx) + }() + + select { + case err := <-errChan: + s.Require().NoError(err) + case <-time.After(15 * time.Second): + // wait 15 seconds for the server to start if no errors + } + + s.T().Log("Chain started") +} + func (s *BaseSuite) SetupSuite() { s.Config = InitConfig() + initValidatorOnce.Do(func() { + s.InitChain() + }) + s.Client, _ = client.NewGreenfieldClient(s.Config.TendermintAddr, s.Config.ChainId) tmClient := client.NewTendermintClient(s.Config.TendermintAddr) s.TmClient = &tmClient diff --git a/e2e/core/config.go b/e2e/core/config.go index f2e59bbe3..4efae3fbd 100644 --- a/e2e/core/config.go +++ b/e2e/core/config.go @@ -26,6 +26,8 @@ type Config struct { SPMnemonics []SPMnemonics `yaml:"SPMnemonics"` SPBLSMnemonic []string `yaml:"SPBLSMnemonic"` Denom string `yaml:"Denom"` + ValidatorHomeDir string `yaml:"ValidatorHomeDir"` + ValidatorTmRPCAddr string `yaml:"ValidatorTmRPCAddr"` } func InitConfig() *Config { @@ -43,6 +45,8 @@ func InitE2eConfig() *Config { ValidatorBlsMnemonic: ParseValidatorBlsMnemonic(0), RelayerMnemonic: ParseRelayerMnemonic(0), ChallengerMnemonic: ParseChallengerMnemonic(0), + ValidatorHomeDir: ParseValidatorHomeDir(0), + ValidatorTmRPCAddr: ParseValidatorTmRPCAddrDir(0), } for i := 0; i < 7; i++ { config.SPMnemonics = append(config.SPMnemonics, ParseSPMnemonics(i)) @@ -105,3 +109,13 @@ func ParseMnemonicFromFile(fileName string) string { } return line } + +// ParseValidatorHomeDir returns the home dir of the validator +func ParseValidatorHomeDir(i int) string { + return fmt.Sprintf("../../deployment/localup/.local/validator%d", i) +} + +// ParseValidatorTmRPCAddrDir returns the home dir of the validator +func ParseValidatorTmRPCAddrDir(i int) string { + return fmt.Sprintf("tcp://0.0.0.0:%d", 26750+i) +} diff --git a/e2e/tests/bridge_test.go b/e2e/tests/bridge_test.go index 4d6df2237..7795a9a44 100644 --- a/e2e/tests/bridge_test.go +++ b/e2e/tests/bridge_test.go @@ -8,9 +8,12 @@ import ( "time" sdkmath "cosmossdk.io/math" + sdkClient "github.com/cosmos/cosmos-sdk/client" "github.com/cosmos/cosmos-sdk/types" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/tx" + "github.com/cosmos/cosmos-sdk/types/tx/signing" + authtx "github.com/cosmos/cosmos-sdk/x/auth/tx" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" banktypes "github.com/cosmos/cosmos-sdk/x/bank/types" crosschaintypes "github.com/cosmos/cosmos-sdk/x/crosschain/types" @@ -23,6 +26,7 @@ import ( "github.com/bnb-chain/greenfield/e2e/core" gnfdtypes "github.com/bnb-chain/greenfield/sdk/types" types2 "github.com/bnb-chain/greenfield/sdk/types" + "github.com/bnb-chain/greenfield/x/bridge/client/cli" bridgetypes "github.com/bnb-chain/greenfield/x/bridge/types" ) @@ -36,6 +40,44 @@ func (s *BridgeTestSuite) SetupSuite() { func (s *BridgeTestSuite) SetupTest() {} +func (s *BridgeTestSuite) TestCliQuery() { + ctx := context.Background() + cliCtx := &sdkClient.Context{Client: s.TmClient.TmClient, Codec: s.Client.GetCodec()} + ctx = context.WithValue(ctx, sdkClient.ClientContextKey, cliCtx) + queryCmd := cli.GetQueryCmd() + + // query params + queryCmd.SetArgs([]string{"params"}) + s.Require().NoError(queryCmd.ExecuteContext(ctx)) +} + +func (s *BridgeTestSuite) TestCliTx() { + ctx := context.Background() + txCfg := authtx.NewTxConfig(s.Client.GetCodec(), []signing.SignMode{signing.SignMode_SIGN_MODE_EIP_712}) + cliCtx := &sdkClient.Context{ + FromAddress: s.Validator.GetAddr(), + Client: s.TmClient.TmClient, + InterfaceRegistry: s.Client.GetCodec().InterfaceRegistry(), + Codec: s.Client.GetCodec(), + From: s.Validator.String(), + AccountRetriever: authtypes.AccountRetriever{}, + ChainID: s.Config.ChainId, + TxConfig: txCfg, + SkipConfirm: true, + Simulate: true, + } + ctx = context.WithValue(ctx, sdkClient.ClientContextKey, cliCtx) + txCmd := cli.GetTxCmd() + + // wrong to address + txCmd.SetArgs([]string{"transfer-out", "test", "1000000000000000000BNB"}) + s.Require().Contains(txCmd.ExecuteContext(ctx).Error(), "invalid address hex length") + + // tx transfer-out + txCmd.SetArgs([]string{"transfer-out", sdk.AccAddress("test").String(), "1000000000000000000BNB"}) + s.Require().NoError(txCmd.ExecuteContext(ctx)) +} + func (s *BridgeTestSuite) TestTransferOut() { users := s.GenAndChargeAccounts(2, 1000000) diff --git a/e2e/tests/payment_test.go b/e2e/tests/payment_test.go index b3fb6fc18..7ceaa3bda 100644 --- a/e2e/tests/payment_test.go +++ b/e2e/tests/payment_test.go @@ -3009,6 +3009,8 @@ func (s *PaymentTestSuite) TestDiscontinue_InBlocks_WithPriceChangeReserveTimeCh s.Require().Equal(queryHeadObjectResponse.ObjectInfo.ObjectStatus, storagetypes.OBJECT_STATUS_CREATED) time.Sleep(200 * time.Millisecond) } + userStreamRecord := s.getStreamRecord(user.GetAddr().String()) + s.Require().True(userStreamRecord.LockBalance.IsPositive()) // update new price msgUpdatePrice := &sptypes.MsgUpdateSpStoragePrice{ @@ -3107,6 +3109,9 @@ func (s *PaymentTestSuite) TestDiscontinue_InBlocks_WithPriceChangeReserveTimeCh s.Require().Equal(streamRecordsAfter.GVG.NetflowRate.Sub(streamRecordsBefore.GVG.NetflowRate).Int64(), int64(0)) s.Require().True(streamRecordsAfter.Tax.NetflowRate.Sub(streamRecordsBefore.Tax.NetflowRate).Int64() <= int64(0)) // there are other auto settling + s.Require().True(streamRecordsAfter.User.LockBalance.IsZero()) + s.Require().True(streamRecordsAfter.User.StaticBalance.Int64() == userStreamRecord.LockBalance.Int64()) + // revert price msgUpdatePrice = &sptypes.MsgUpdateSpStoragePrice{ SpAddress: sp.OperatorKey.GetAddr().String(), diff --git a/e2e/tests/virtualgroup_test.go b/e2e/tests/virtualgroup_test.go index a39d39e00..ee8e20535 100644 --- a/e2e/tests/virtualgroup_test.go +++ b/e2e/tests/virtualgroup_test.go @@ -65,6 +65,16 @@ func (s *VirtualGroupTestSuite) queryGlobalVirtualGroupByFamily(familyID uint32) return resp.GlobalVirtualGroups } +func (s *VirtualGroupTestSuite) queryAvailableGlobalVirtualGroupFamilies(familyIds []uint32) []uint32 { + resp, err := s.Client.AvailableGlobalVirtualGroupFamilies( + context.Background(), + &virtualgroupmoduletypes.AvailableGlobalVirtualGroupFamiliesRequest{ + GlobalVirtualGroupFamilyIds: familyIds, + }) + s.Require().NoError(err) + return resp.GlobalVirtualGroupFamilyIds +} + func (s *VirtualGroupTestSuite) TestBasic() { primarySP := s.BaseSuite.PickStorageProvider() @@ -82,6 +92,9 @@ func (s *VirtualGroupTestSuite) TestBasic() { gvg = g } + availableGvgFamilyIds := s.queryAvailableGlobalVirtualGroupFamilies([]uint32{gvg.FamilyId}) + s.Require().Equal(availableGvgFamilyIds[0], gvg.FamilyId) + srcGVGs := s.queryGlobalVirtualGroupByFamily(gvg.FamilyId) var secondarySPIDs []uint32 diff --git a/testutil/sample/sample.go b/testutil/sample/sample.go index 96aaf945c..be3f0b796 100644 --- a/testutil/sample/sample.go +++ b/testutil/sample/sample.go @@ -5,22 +5,35 @@ import ( "encoding/hex" "github.com/cometbft/cometbft/crypto/tmhash" - "github.com/cosmos/cosmos-sdk/crypto/keys/ed25519" + "github.com/cosmos/cosmos-sdk/crypto/keys/eth/ethsecp256k1" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/prysmaticlabs/prysm/crypto/bls" ) -// AccAddress returns a sample account address -func AccAddress() string { - pk := ed25519.GenPrivKey().PubKey() - addr := pk.Address() - return sdk.AccAddress(addr).String() +func RandAccAddress() sdk.AccAddress { + pk, err := ethsecp256k1.GenPrivKey() + if err != nil { + panic(err) + } + return sdk.AccAddress(pk.PubKey().Address()) } -func RandAccAddress() sdk.AccAddress { - pk := ed25519.GenPrivKey().PubKey() - addr := pk.Address() - return sdk.AccAddress(addr) +func RandAccAddressHex() string { + pk, err := ethsecp256k1.GenPrivKey() + if err != nil { + panic(err) + } + return sdk.AccAddress(pk.PubKey().Address()).String() +} + +func RandSignBytes() (addr sdk.AccAddress, signBytes, sig []byte) { + signBytes = RandStr(256) + privKey, _ := ethsecp256k1.GenPrivKey() + + sig, _ = privKey.Sign(sdk.Keccak256(signBytes)) + pk := privKey.PubKey() + addr = sdk.AccAddress(pk.Address()) + return addr, signBytes, sig } func Checksum() []byte { diff --git a/types/grn_test.go b/types/grn_test.go index f85fd1e9d..f5dd6a33b 100644 --- a/types/grn_test.go +++ b/types/grn_test.go @@ -17,7 +17,7 @@ func TestGRNBasic(t *testing.T) { var grn types3.GRN testBucketName := storageutils.GenRandomBucketName() testObjectName := storageutils.GenRandomObjectName() - testAcc := sample.AccAddress() + testAcc := sample.RandAccAddressHex() testGroupName := storageutils.GenRandomGroupName() err := grn.ParseFromString("grn:b::"+testBucketName, false) diff --git a/x/bridge/client/cli/query.go b/x/bridge/client/cli/query.go index f15f54f23..dbdb274b9 100644 --- a/x/bridge/client/cli/query.go +++ b/x/bridge/client/cli/query.go @@ -10,7 +10,7 @@ import ( ) // GetQueryCmd returns the cli query commands for this module -func GetQueryCmd(queryRoute string) *cobra.Command { +func GetQueryCmd() *cobra.Command { // Group bridge queries under a subcommand cmd := &cobra.Command{ Use: types.ModuleName, diff --git a/x/bridge/keeper/cross_app.go b/x/bridge/keeper/cross_app.go index de0a4dc08..15c3e4e16 100644 --- a/x/bridge/keeper/cross_app.go +++ b/x/bridge/keeper/cross_app.go @@ -2,11 +2,9 @@ package keeper import ( "encoding/hex" - "math/big" "cosmossdk.io/errors" sdk "github.com/cosmos/cosmos-sdk/types" - sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" crosschaintypes "github.com/cosmos/cosmos-sdk/x/crosschain/types" "github.com/bnb-chain/greenfield/x/bridge/types" @@ -38,17 +36,6 @@ func NewTransferOutApp(keeper Keeper) *TransferOutApp { } } -func (app *TransferOutApp) CheckPackage(refundPackage *types.TransferOutRefundPackage) error { - if refundPackage.RefundAddr.Empty() { - return errors.Wrapf(sdkerrors.ErrInvalidAddress, "refund address is empty") - } - - if refundPackage.RefundAmount.Cmp(big.NewInt(0)) < 0 { - return errors.Wrapf(types.ErrInvalidAmount, "amount to refund should not be negative") - } - return nil -} - func (app *TransferOutApp) ExecuteAckPackage(ctx sdk.Context, appCtx *sdk.CrossChainAppContext, payload []byte) sdk.ExecuteResult { if len(payload) == 0 { return sdk.ExecuteResult{} @@ -58,22 +45,14 @@ func (app *TransferOutApp) ExecuteAckPackage(ctx sdk.Context, appCtx *sdk.CrossC refundPackage, err := types.DeserializeTransferOutRefundPackage(payload) if err != nil { - app.bridgeKeeper.Logger(ctx).Error("unmarshal transfer out refund claim error", "err", err.Error(), "claim", hex.EncodeToString(payload)) - return sdk.ExecuteResult{ - Err: err, - } - } - - err = app.CheckPackage(refundPackage) - if err != nil { - app.bridgeKeeper.Logger(ctx).Error("check transfer out refund package error", "err", err.Error(), "claim", hex.EncodeToString(payload)) + app.bridgeKeeper.Logger(ctx).Error("decode transfer out refund claim error", "err", err.Error(), "claim", hex.EncodeToString(payload)) return sdk.ExecuteResult{ Err: err, } } denom := app.bridgeKeeper.stakingKeeper.BondDenom(ctx) // only support native token so far - err = app.bridgeKeeper.bankKeeper.SendCoinsFromModuleToAccount(ctx, crosschaintypes.ModuleName, refundPackage.RefundAddr, + err = app.bridgeKeeper.bankKeeper.SendCoinsFromModuleToAccount(ctx, crosschaintypes.ModuleName, refundPackage.RefundAddress, sdk.Coins{ sdk.Coin{ Denom: denom, @@ -89,7 +68,7 @@ func (app *TransferOutApp) ExecuteAckPackage(ctx sdk.Context, appCtx *sdk.CrossC } err = ctx.EventManager().EmitTypedEvent(&types.EventCrossTransferOutRefund{ - RefundAddress: refundPackage.RefundAddr.String(), + RefundAddress: refundPackage.RefundAddress.String(), Amount: &sdk.Coin{ Denom: denom, Amount: sdk.NewIntFromBigInt(refundPackage.RefundAmount), @@ -116,12 +95,6 @@ func (app *TransferOutApp) ExecuteFailAckPackage(ctx sdk.Context, appCtx *sdk.Cr } } - if transferOutPackage.Amount.Cmp(big.NewInt(0)) < 0 { - return sdk.ExecuteResult{ - Err: errors.Wrapf(types.ErrInvalidAmount, "amount to refund should not be negative"), - } - } - denom := app.bridgeKeeper.stakingKeeper.BondDenom(ctx) // only support native token so far err = app.bridgeKeeper.bankKeeper.SendCoinsFromModuleToAccount(ctx, crosschaintypes.ModuleName, transferOutPackage.RefundAddress, sdk.Coins{ @@ -173,22 +146,6 @@ func NewTransferInApp(bridgeKeeper Keeper) *TransferInApp { } } -func (app *TransferInApp) CheckTransferInSynPackage(transferInPackage *types.TransferInSynPackage) error { - if transferInPackage.Amount.Cmp(big.NewInt(0)) < 0 { - return errors.Wrapf(types.ErrInvalidAmount, "amount should not be negative") - } - - if transferInPackage.ReceiverAddress.Empty() { - return errors.Wrapf(sdkerrors.ErrInvalidAddress, "receiver address should not be empty") - } - - if transferInPackage.RefundAddress.Empty() { - return errors.Wrapf(types.ErrInvalidAddress, "refund address should not be empty") - } - - return nil -} - func (app *TransferInApp) ExecuteAckPackage(ctx sdk.Context, header *sdk.CrossChainAppContext, payload []byte) sdk.ExecuteResult { app.bridgeKeeper.Logger(ctx).Error("received transfer in ack package", "payload", hex.EncodeToString(payload)) return sdk.ExecuteResult{} @@ -202,14 +159,8 @@ func (app *TransferInApp) ExecuteFailAckPackage(ctx sdk.Context, header *sdk.Cro func (app *TransferInApp) ExecuteSynPackage(ctx sdk.Context, appCtx *sdk.CrossChainAppContext, payload []byte) sdk.ExecuteResult { transferInPackage, err := types.DeserializeTransferInSynPackage(payload) if err != nil { - app.bridgeKeeper.Logger(ctx).Error("unmarshal transfer in claim error", "err", err.Error(), "claim", string(payload)) - panic("unmarshal transfer in claim error") - } - - err = app.CheckTransferInSynPackage(transferInPackage) - if err != nil { - app.bridgeKeeper.Logger(ctx).Error("check transfer in package error", "err", err.Error(), "claim", string(payload)) - panic(err) + app.bridgeKeeper.Logger(ctx).Error("decode transfer in claim error", "err", err.Error(), "claim", string(payload)) + panic("decode transfer in claim error") } denom := app.bridgeKeeper.stakingKeeper.BondDenom(ctx) @@ -218,14 +169,14 @@ func (app *TransferInApp) ExecuteSynPackage(ctx sdk.Context, appCtx *sdk.CrossCh err = app.bridgeKeeper.bankKeeper.SendCoinsFromModuleToAccount(ctx, crosschaintypes.ModuleName, transferInPackage.ReceiverAddress, sdk.Coins{amount}) if err != nil { app.bridgeKeeper.Logger(ctx).Error("send coins error", "err", err.Error()) - refundPackage, err := app.bridgeKeeper.GetRefundTransferInPayload(transferInPackage, uint32(types.REFUND_REASON_INSUFFICIENT_BALANCE)) - if err != nil { - app.bridgeKeeper.Logger(ctx).Error("get refund transfer in payload error", "err", err.Error()) - panic(err) + refundPackage, refundErr := app.bridgeKeeper.GetRefundTransferInPayload(transferInPackage, uint32(types.REFUND_REASON_INSUFFICIENT_BALANCE)) + if refundErr != nil { + app.bridgeKeeper.Logger(ctx).Error("get refund transfer in payload error", "err", refundErr.Error()) + panic(refundErr) } return sdk.ExecuteResult{ Payload: refundPackage, - Err: errors.Wrapf(types.ErrInvalidLength, "balance of cross chain module is insufficient"), + Err: errors.Wrapf(types.ErrInvalidPackage, "send coins error: %s", err.Error()), } } diff --git a/x/bridge/keeper/cross_app_test.go b/x/bridge/keeper/cross_app_test.go index beee21139..970f5bffe 100644 --- a/x/bridge/keeper/cross_app_test.go +++ b/x/bridge/keeper/cross_app_test.go @@ -1,74 +1,22 @@ package keeper_test import ( - "bytes" + "fmt" "math/big" - "testing" - "github.com/cosmos/cosmos-sdk/crypto/hd" - "github.com/cosmos/cosmos-sdk/testutil" sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/ethereum/go-ethereum/accounts/abi" "github.com/golang/mock/gomock" - "github.com/stretchr/testify/require" "github.com/bnb-chain/greenfield/x/bridge/keeper" "github.com/bnb-chain/greenfield/x/bridge/types" ) -func TestTransferOutCheck(t *testing.T) { - tests := []struct { - refundPackage types.TransferOutRefundPackage - expectedPass bool - errorMsg string - }{ - { - refundPackage: types.TransferOutRefundPackage{ - RefundAmount: big.NewInt(1), - RefundAddr: []byte{}, - RefundReason: 0, - }, - expectedPass: false, - errorMsg: "refund address is empty", - }, - { - refundPackage: types.TransferOutRefundPackage{ - RefundAmount: big.NewInt(-1), - RefundAddr: bytes.Repeat([]byte{1}, 20), - RefundReason: 0, - }, - expectedPass: false, - errorMsg: "amount to refund should not be negative", - }, - { - refundPackage: types.TransferOutRefundPackage{ - RefundAmount: big.NewInt(1), - RefundAddr: bytes.Repeat([]byte{1}, 20), - RefundReason: 0, - }, - expectedPass: true, - }, - } - - crossApp := keeper.NewTransferOutApp(keeper.Keeper{}) - for _, test := range tests { - err := crossApp.CheckPackage(&test.refundPackage) - if test.expectedPass { - require.Nil(t, err, "error should be nil") - } else { - require.NotNil(t, err, " error should not be nil") - require.Contains(t, err.Error(), test.errorMsg) - } - } -} - func (s *TestSuite) TestTransferOutAck() { - addr1, _, err := testutil.GenerateCoinKey(hd.Secp256k1, s.cdc) - s.Require().Nil(err, "generate key failed") - refundPackage := types.TransferOutRefundPackage{ - RefundAmount: big.NewInt(1), - RefundAddr: addr1, - RefundReason: 1, + RefundAmount: big.NewInt(1), + RefundAddress: sdk.AccAddress("refundAddress"), + RefundReason: 1, } packageBytes, err := refundPackage.Serialize() @@ -77,20 +25,32 @@ func (s *TestSuite) TestTransferOutAck() { transferOutApp := keeper.NewTransferOutApp(*s.bridgeKeeper) s.stakingKeeper.EXPECT().BondDenom(gomock.Any()).Return("BNB").AnyTimes() - s.bankKeeper.EXPECT().SendCoinsFromModuleToAccount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - result := transferOutApp.ExecuteAckPackage(s.ctx, &sdk.CrossChainAppContext{Sequence: 1}, packageBytes) + // empty payload + result := transferOutApp.ExecuteAckPackage(s.ctx, &sdk.CrossChainAppContext{Sequence: 1}, nil) + s.Require().Nil(result.Err, "result should be nil") + s.Require().Nil(result.Payload, "result should be nil") + + // wrong payload + result = transferOutApp.ExecuteAckPackage(s.ctx, &sdk.CrossChainAppContext{Sequence: 1}, []byte{1}) + s.Require().Contains(result.Err.Error(), "deserialize transfer out refund package failed") + + // send coins failed + s.bankKeeper.EXPECT().SendCoinsFromModuleToAccount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("test send coins error")).Times(1) + result = transferOutApp.ExecuteAckPackage(s.ctx, &sdk.CrossChainAppContext{Sequence: 1}, packageBytes) + s.Require().Contains(result.Err.Error(), "test send coins error") + + // success case + s.bankKeeper.EXPECT().SendCoinsFromModuleToAccount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1) + result = transferOutApp.ExecuteAckPackage(s.ctx, &sdk.CrossChainAppContext{Sequence: 1}, packageBytes) s.Require().Nil(err, result.Err, "error should be nil") } -func (s *TestSuite) TestTransferOutFailAck() { - addr1, _, err := testutil.GenerateCoinKey(hd.Secp256k1, s.cdc) - s.Require().Nil(err, "generate key failed") - +func (s *TestSuite) TestTransferOutSynAndFailAck() { synPackage := types.TransferOutSynPackage{ Amount: big.NewInt(1), Recipient: sdk.AccAddress{}, - RefundAddress: addr1, + RefundAddress: sdk.AccAddress("refundAddress"), } packageBytes, err := synPackage.Serialize() @@ -100,75 +60,32 @@ func (s *TestSuite) TestTransferOutFailAck() { s.crossChainKeeper.EXPECT().CreateRawIBCPackageWithFee(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(uint64(0), nil).AnyTimes() s.stakingKeeper.EXPECT().BondDenom(gomock.Any()).Return("BNB").AnyTimes() - s.bankKeeper.EXPECT().SendCoinsFromModuleToAccount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() - result := transferOutApp.ExecuteFailAckPackage(s.ctx, &sdk.CrossChainAppContext{Sequence: 1}, packageBytes) - s.Require().Nil(err, result.Err, "error should be nil") -} + // syn package + result := transferOutApp.ExecuteSynPackage(s.ctx, &sdk.CrossChainAppContext{Sequence: 1}, nil) + s.Require().Nil(result.Payload, "result should be nil") -func TestTransferInCheck(t *testing.T) { - tests := []struct { - transferInPackage types.TransferInSynPackage - expectedPass bool - errorMsg string - }{ - { - transferInPackage: types.TransferInSynPackage{ - Amount: big.NewInt(1), - ReceiverAddress: sdk.AccAddress{}, - RefundAddress: sdk.AccAddress{1}, - }, - expectedPass: false, - errorMsg: "receiver address should not be empty", - }, - { - transferInPackage: types.TransferInSynPackage{ - Amount: big.NewInt(1), - ReceiverAddress: sdk.AccAddress{1}, - RefundAddress: sdk.AccAddress{}, - }, - expectedPass: false, - errorMsg: "refund address should not be empty", - }, - { - transferInPackage: types.TransferInSynPackage{ - Amount: big.NewInt(-1), - ReceiverAddress: sdk.AccAddress{1}, - RefundAddress: sdk.AccAddress{1}, - }, - expectedPass: false, - errorMsg: "amount should not be negative", - }, - { - transferInPackage: types.TransferInSynPackage{ - Amount: big.NewInt(1), - ReceiverAddress: sdk.AccAddress{1}, - RefundAddress: sdk.AccAddress{1}, - }, - expectedPass: true, - }, - } + // fail ack package + // wrong payload + result = transferOutApp.ExecuteFailAckPackage(s.ctx, &sdk.CrossChainAppContext{Sequence: 1}, []byte{1}) + s.Require().Contains(result.Err.Error(), "deserialize transfer out syn package failed") - crossApp := keeper.NewTransferInApp(keeper.Keeper{}) - for _, test := range tests { - err := crossApp.CheckTransferInSynPackage(&test.transferInPackage) - if test.expectedPass { - require.Nil(t, err, "error should be nil") - } else { - require.NotNil(t, err, " error should not be nil") - require.Contains(t, err.Error(), test.errorMsg) - } - } -} + // send coins failed + s.bankKeeper.EXPECT().SendCoinsFromModuleToAccount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("test send coins error")).Times(1) + result = transferOutApp.ExecuteFailAckPackage(s.ctx, &sdk.CrossChainAppContext{Sequence: 1}, packageBytes) + s.Require().Contains(result.Err.Error(), "send coins error") -func (s *TestSuite) TestTransferInSyn() { - addr1, _, err := testutil.GenerateCoinKey(hd.Secp256k1, s.cdc) - s.Require().Nil(err, "generate key failed") + // success case + s.bankKeeper.EXPECT().SendCoinsFromModuleToAccount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1) + result = transferOutApp.ExecuteFailAckPackage(s.ctx, &sdk.CrossChainAppContext{Sequence: 1}, packageBytes) + s.Require().Nil(err, result.Err, "error should be nil") +} +func (s *TestSuite) TestTransferIn() { transferInSynPackage := types.TransferInSynPackage{ Amount: big.NewInt(1), - ReceiverAddress: addr1, - RefundAddress: sdk.AccAddress{1}, + ReceiverAddress: sdk.AccAddress("receiverAddress"), + RefundAddress: sdk.AccAddress("refundAddress"), } packageBytes, err := transferInSynPackage.Serialize() @@ -178,8 +95,35 @@ func (s *TestSuite) TestTransferInSyn() { s.crossChainKeeper.EXPECT().CreateRawIBCPackageWithFee(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(uint64(0), nil).AnyTimes() s.stakingKeeper.EXPECT().BondDenom(gomock.Any()).Return("BNB").AnyTimes() - s.bankKeeper.EXPECT().SendCoinsFromModuleToAccount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).AnyTimes() + // syn package + // wrong payload + s.Require().Panics(func() { transferInApp.ExecuteSynPackage(s.ctx, &sdk.CrossChainAppContext{Sequence: 1}, []byte{1}) }) + + // send coins failed and refund + s.bankKeeper.EXPECT().SendCoinsFromModuleToAccount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("test send coins error")).Times(1) result := transferInApp.ExecuteSynPackage(s.ctx, &sdk.CrossChainAppContext{Sequence: 1}, packageBytes) - s.Require().Nil(err, result.Err, "error should be nil") + s.Require().Contains(result.Err.Error(), "test send coins error") + unpacked, err := types.TransferInRefundPackageArgs.Unpack(result.Payload) + s.Require().NoError(err) + + unpackedStruct := abi.ConvertType(unpacked[0], types.TransferInRefundPackageStruct{}) + pkgStruct, ok := unpackedStruct.(types.TransferInRefundPackageStruct) + s.Require().True(ok) + + s.Require().Equal(transferInSynPackage.Amount, pkgStruct.RefundAmount) + s.Require().Equal(transferInSynPackage.RefundAddress.String(), pkgStruct.RefundAddress.String()) + s.Require().Equal(uint32(types.REFUND_REASON_INSUFFICIENT_BALANCE), pkgStruct.RefundReason) + + // success case + s.bankKeeper.EXPECT().SendCoinsFromModuleToAccount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil).Times(1) + result = transferInApp.ExecuteSynPackage(s.ctx, &sdk.CrossChainAppContext{Sequence: 1}, packageBytes) + s.Require().Nil(result.Err, "error should be nil") + + // unexpected package type + result = transferInApp.ExecuteAckPackage(s.ctx, &sdk.CrossChainAppContext{Sequence: 1}, nil) + s.Require().Nil(result.Payload, "result should be nil") + + result = transferInApp.ExecuteFailAckPackage(s.ctx, &sdk.CrossChainAppContext{Sequence: 1}, nil) + s.Require().Nil(result.Payload, "result should be nil") } diff --git a/x/bridge/genesis.go b/x/bridge/keeper/genesis.go similarity index 64% rename from x/bridge/genesis.go rename to x/bridge/keeper/genesis.go index 270b49939..58b9bb91c 100644 --- a/x/bridge/genesis.go +++ b/x/bridge/keeper/genesis.go @@ -1,14 +1,13 @@ -package bridge +package keeper import ( sdk "github.com/cosmos/cosmos-sdk/types" - "github.com/bnb-chain/greenfield/x/bridge/keeper" "github.com/bnb-chain/greenfield/x/bridge/types" ) // InitGenesis initializes the module's state from a provided genesis state. -func InitGenesis(ctx sdk.Context, k keeper.Keeper, genState types.GenesisState) { +func InitGenesis(ctx sdk.Context, k Keeper, genState types.GenesisState) { err := k.SetParams(ctx, genState.Params) if err != nil { panic(err) @@ -16,7 +15,7 @@ func InitGenesis(ctx sdk.Context, k keeper.Keeper, genState types.GenesisState) } // ExportGenesis returns the module's exported genesis -func ExportGenesis(ctx sdk.Context, k keeper.Keeper) *types.GenesisState { +func ExportGenesis(ctx sdk.Context, k Keeper) *types.GenesisState { genesis := types.DefaultGenesis() genesis.Params = k.GetParams(ctx) diff --git a/x/bridge/keeper/genesis_test.go b/x/bridge/keeper/genesis_test.go new file mode 100644 index 000000000..4bf996edf --- /dev/null +++ b/x/bridge/keeper/genesis_test.go @@ -0,0 +1,29 @@ +package keeper_test + +import ( + sdkmath "cosmossdk.io/math" + + "github.com/bnb-chain/greenfield/x/bridge/keeper" + "github.com/bnb-chain/greenfield/x/bridge/types" +) + +func (s *TestSuite) TestExportGenesis() { + ctx := s.ctx + + s.Require().NoError(s.bridgeKeeper.SetParams(ctx, types.DefaultParams())) + exportGenesis := keeper.ExportGenesis(ctx, *s.bridgeKeeper) + + s.Require().Equal(types.DefaultParams().BscTransferOutRelayerFee, exportGenesis.Params.BscTransferOutRelayerFee) + s.Require().Equal(types.DefaultParams().BscTransferOutAckRelayerFee, exportGenesis.Params.BscTransferOutAckRelayerFee) +} + +func (s *TestSuite) TestInitGenesis() { + g := types.DefaultGenesis() + k := s.bridgeKeeper + keeper.InitGenesis(s.ctx, *k, *g) + + // Check that the genesis state was set correctly. + params := k.GetParams(s.ctx) + s.Require().Equal(sdkmath.NewInt(250000000000000), params.BscTransferOutRelayerFee) + s.Require().Equal(sdkmath.NewInt(0), params.BscTransferOutAckRelayerFee) +} diff --git a/x/bridge/module.go b/x/bridge/module.go index 0c6037851..42ba17b2c 100644 --- a/x/bridge/module.go +++ b/x/bridge/module.go @@ -81,7 +81,7 @@ func (a AppModuleBasic) GetTxCmd() *cobra.Command { // GetQueryCmd returns the root query command for the module. The subcommands of this root command are used by end-users to generate new queries to the subset of the state defined by the module func (AppModuleBasic) GetQueryCmd() *cobra.Command { - return cli.GetQueryCmd(types.StoreKey) + return cli.GetQueryCmd() } // ---------------------------------------------------------------------------- @@ -126,14 +126,14 @@ func (am AppModule) InitGenesis(ctx sdk.Context, cdc codec.JSONCodec, gs json.Ra // Initialize global index to index in genesis state cdc.MustUnmarshalJSON(gs, &genState) - InitGenesis(ctx, am.keeper, genState) + keeper.InitGenesis(ctx, am.keeper, genState) return []abci.ValidatorUpdate{} } // ExportGenesis returns the module's exported genesis state as raw JSON bytes. func (am AppModule) ExportGenesis(ctx sdk.Context, cdc codec.JSONCodec) json.RawMessage { - genState := ExportGenesis(ctx, am.keeper) + genState := keeper.ExportGenesis(ctx, am.keeper) return cdc.MustMarshalJSON(genState) } diff --git a/x/bridge/module_simulation.go b/x/bridge/module_simulation.go index b8881daa3..86d4388c9 100644 --- a/x/bridge/module_simulation.go +++ b/x/bridge/module_simulation.go @@ -16,7 +16,7 @@ import ( // avoid unused import issue var ( - _ = sample.AccAddress + _ = sample.RandAccAddressHex _ = bridgesimulation.FindAccount _ = simulation.MsgEntryKind _ = baseapp.Paramspace diff --git a/x/bridge/types/message_transfer_out_test.go b/x/bridge/types/message_transfer_out_test.go index 97ed5bf8f..f75f0936a 100644 --- a/x/bridge/types/message_transfer_out_test.go +++ b/x/bridge/types/message_transfer_out_test.go @@ -25,7 +25,7 @@ func TestMsgTransferOut_ValidateBasic(t *testing.T) { }, { name: "invalid to address", msg: MsgTransferOut{ - From: sample.AccAddress(), + From: sample.RandAccAddressHex(), To: "invalid address", }, err: sdkerrors.ErrInvalidAddress, @@ -33,7 +33,7 @@ func TestMsgTransferOut_ValidateBasic(t *testing.T) { { name: "invalid amount", msg: MsgTransferOut{ - From: sample.AccAddress(), + From: sample.RandAccAddressHex(), To: "0x0000000000000000000000000000000000001000", Amount: nil, }, @@ -42,7 +42,7 @@ func TestMsgTransferOut_ValidateBasic(t *testing.T) { { name: "invalid amount", msg: MsgTransferOut{ - From: sample.AccAddress(), + From: sample.RandAccAddressHex(), To: "0x0000000000000000000000000000000000001000", Amount: &sdk.Coin{ Denom: "%%%%%", @@ -54,7 +54,7 @@ func TestMsgTransferOut_ValidateBasic(t *testing.T) { { name: "invalid amount", msg: MsgTransferOut{ - From: sample.AccAddress(), + From: sample.RandAccAddressHex(), To: "0x0000000000000000000000000000000000001000", Amount: &sdk.Coin{ Denom: "coin", @@ -66,7 +66,7 @@ func TestMsgTransferOut_ValidateBasic(t *testing.T) { { name: "invalid amount", msg: MsgTransferOut{ - From: sample.AccAddress(), + From: sample.RandAccAddressHex(), To: "0x0000000000000000000000000000000000001000", Amount: &sdk.Coin{ Denom: "coin", @@ -78,7 +78,7 @@ func TestMsgTransferOut_ValidateBasic(t *testing.T) { { name: "invalid amount", msg: MsgTransferOut{ - From: sample.AccAddress(), + From: sample.RandAccAddressHex(), To: "0x0000000000000000000000000000000000001000", Amount: &sdk.Coin{ Denom: "coin", diff --git a/x/bridge/types/types.go b/x/bridge/types/types.go index 97046a951..37847b670 100644 --- a/x/bridge/types/types.go +++ b/x/bridge/types/types.go @@ -39,19 +39,19 @@ type TransferOutSynPackageStruct struct { } var ( - transferOutSynPackageType, _ = abi.NewType("tuple", "", []abi.ArgumentMarshaling{ + TransferOutSynPackageType, _ = abi.NewType("tuple", "", []abi.ArgumentMarshaling{ {Name: "Amount", Type: "uint256"}, {Name: "Recipient", Type: "address"}, {Name: "RefundAddress", Type: "address"}, }) - transferOutSynPackageArgs = abi.Arguments{ - {Type: transferOutSynPackageType}, + TransferOutSynPackageArgs = abi.Arguments{ + {Type: TransferOutSynPackageType}, } ) func (pkg *TransferOutSynPackage) Serialize() ([]byte, error) { - return transferOutSynPackageArgs.Pack(&TransferOutSynPackageStruct{ + return TransferOutSynPackageArgs.Pack(&TransferOutSynPackageStruct{ SafeBigInt(pkg.Amount), common.BytesToAddress(pkg.Recipient), common.BytesToAddress(pkg.RefundAddress), @@ -59,9 +59,9 @@ func (pkg *TransferOutSynPackage) Serialize() ([]byte, error) { } func DeserializeTransferOutSynPackage(serializedPackage []byte) (*TransferOutSynPackage, error) { - unpacked, err := transferOutSynPackageArgs.Unpack(serializedPackage) + unpacked, err := TransferOutSynPackageArgs.Unpack(serializedPackage) if err != nil { - return nil, errors.Wrapf(ErrInvalidPackage, "deserialize transfer out sync package failed") + return nil, errors.Wrapf(ErrInvalidPackage, "deserialize transfer out syn package failed") } unpackedStruct := abi.ConvertType(unpacked[0], TransferOutSynPackageStruct{}) @@ -79,39 +79,43 @@ func DeserializeTransferOutSynPackage(serializedPackage []byte) (*TransferOutSyn } type TransferOutRefundPackage struct { - RefundAmount *big.Int - RefundAddr sdk.AccAddress - RefundReason uint32 + RefundAmount *big.Int + RefundAddress sdk.AccAddress + RefundReason uint32 } type TransferOutRefundPackageStruct struct { - RefundAmount *big.Int - RefundAddr common.Address - RefundReason uint32 + RefundAmount *big.Int + RefundAddress common.Address + RefundReason uint32 } var ( - transferOutRefundPackageType, _ = abi.NewType("tuple", "", []abi.ArgumentMarshaling{ + TransferOutRefundPackageType, _ = abi.NewType("tuple", "", []abi.ArgumentMarshaling{ {Name: "RefundAmount", Type: "uint256"}, - {Name: "RefundAddr", Type: "address"}, + {Name: "RefundAddress", Type: "address"}, {Name: "RefundReason", Type: "uint32"}, }) - transferOutRefundPackageArgs = abi.Arguments{ - {Type: transferOutRefundPackageType}, + TransferOutRefundPackageArgs = abi.Arguments{ + {Type: TransferOutRefundPackageType}, } ) func (pkg *TransferOutRefundPackage) Serialize() ([]byte, error) { - return transferOutRefundPackageArgs.Pack(&TransferOutRefundPackageStruct{ + if pkg.RefundAmount.Cmp(big.NewInt(0)) < 0 { + return nil, errors.Wrapf(ErrInvalidPackage, "refund amount should not be negative") + } + + return TransferOutRefundPackageArgs.Pack(&TransferOutRefundPackageStruct{ SafeBigInt(pkg.RefundAmount), - common.BytesToAddress(pkg.RefundAddr), + common.BytesToAddress(pkg.RefundAddress), pkg.RefundReason, }) } func DeserializeTransferOutRefundPackage(serializedPackage []byte) (*TransferOutRefundPackage, error) { - unpacked, err := transferOutRefundPackageArgs.Unpack(serializedPackage) + unpacked, err := TransferOutRefundPackageArgs.Unpack(serializedPackage) if err != nil { return nil, errors.Wrapf(ErrInvalidPackage, "deserialize transfer out refund package failed") } @@ -124,7 +128,7 @@ func DeserializeTransferOutRefundPackage(serializedPackage []byte) (*TransferOut tp := TransferOutRefundPackage{ pkgStruct.RefundAmount, - pkgStruct.RefundAddr.Bytes(), + pkgStruct.RefundAddress.Bytes(), pkgStruct.RefundReason, } return &tp, nil @@ -143,19 +147,19 @@ type TransferInSynPackageStruct struct { } var ( - transferInSynPackageType, _ = abi.NewType("tuple", "", []abi.ArgumentMarshaling{ + TransferInSynPackageType, _ = abi.NewType("tuple", "", []abi.ArgumentMarshaling{ {Name: "Amount", Type: "uint256"}, {Name: "ReceiverAddress", Type: "address"}, {Name: "RefundAddress", Type: "address"}, }) - transferInSynPackageArgs = abi.Arguments{ - {Type: transferInSynPackageType}, + TransferInSynPackageArgs = abi.Arguments{ + {Type: TransferInSynPackageType}, } ) func (pkg *TransferInSynPackage) Serialize() ([]byte, error) { - return transferInSynPackageArgs.Pack(&TransferInSynPackageStruct{ + return TransferInSynPackageArgs.Pack(&TransferInSynPackageStruct{ SafeBigInt(pkg.Amount), common.BytesToAddress(pkg.ReceiverAddress), common.BytesToAddress(pkg.RefundAddress), @@ -163,7 +167,7 @@ func (pkg *TransferInSynPackage) Serialize() ([]byte, error) { } func DeserializeTransferInSynPackage(serializedPackage []byte) (*TransferInSynPackage, error) { - unpacked, err := transferInSynPackageArgs.Unpack(serializedPackage) + unpacked, err := TransferInSynPackageArgs.Unpack(serializedPackage) if err != nil { return nil, errors.Wrapf(ErrInvalidPackage, "deserialize transfer in sync package failed") } @@ -195,19 +199,23 @@ type TransferInRefundPackageStruct struct { } var ( - transferInRefundPackageType, _ = abi.NewType("tuple", "", []abi.ArgumentMarshaling{ + TransferInRefundPackageType, _ = abi.NewType("tuple", "", []abi.ArgumentMarshaling{ {Name: "RefundAmount", Type: "uint256"}, - {Name: "RefundAddr", Type: "address"}, + {Name: "RefundAddress", Type: "address"}, {Name: "RefundReason", Type: "uint32"}, }) - transferInRefundPackageArgs = abi.Arguments{ - {Type: transferInRefundPackageType}, + TransferInRefundPackageArgs = abi.Arguments{ + {Type: TransferInRefundPackageType}, } ) func (pkg *TransferInRefundPackage) Serialize() ([]byte, error) { - return transferInRefundPackageArgs.Pack(&TransferInRefundPackageStruct{ + if pkg.RefundAmount.Cmp(big.NewInt(0)) < 0 { + return nil, errors.Wrapf(ErrInvalidPackage, "refund amount should not be negative") + } + + return TransferInRefundPackageArgs.Pack(&TransferInRefundPackageStruct{ SafeBigInt(pkg.RefundAmount), common.BytesToAddress(pkg.RefundAddress), pkg.RefundReason, diff --git a/x/challenge/abci_test.go b/x/challenge/abci_test.go new file mode 100644 index 000000000..4d1386dfa --- /dev/null +++ b/x/challenge/abci_test.go @@ -0,0 +1,224 @@ +package challenge_test + +import ( + "testing" + + "cosmossdk.io/math" + "github.com/cosmos/cosmos-sdk/baseapp" + "github.com/cosmos/cosmos-sdk/codec" + storetypes "github.com/cosmos/cosmos-sdk/store/types" + "github.com/cosmos/cosmos-sdk/testutil" + sdk "github.com/cosmos/cosmos-sdk/types" + moduletestutil "github.com/cosmos/cosmos-sdk/types/module/testutil" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + govtypes "github.com/cosmos/cosmos-sdk/x/gov/types" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/suite" + + "github.com/bnb-chain/greenfield/x/challenge" + "github.com/bnb-chain/greenfield/x/challenge/keeper" + "github.com/bnb-chain/greenfield/x/challenge/types" + sptypes "github.com/bnb-chain/greenfield/x/sp/types" + storagetypes "github.com/bnb-chain/greenfield/x/storage/types" + virtualgrouptypes "github.com/bnb-chain/greenfield/x/virtualgroup/types" +) + +type TestSuite struct { + suite.Suite + + cdc codec.Codec + challengeKeeper *keeper.Keeper + + bankKeeper *types.MockBankKeeper + storageKeeper *types.MockStorageKeeper + spKeeper *types.MockSpKeeper + stakingKeeper *types.MockStakingKeeper + paymentKeeper *types.MockPaymentKeeper + + ctx sdk.Context + queryClient types.QueryClient + msgServer types.MsgServer +} + +func (s *TestSuite) SetupTest() { + encCfg := moduletestutil.MakeTestEncodingConfig(challenge.AppModuleBasic{}) + key := storetypes.NewKVStoreKey(types.StoreKey) + testCtx := testutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + + // set mock randao mix + randaoMix := sdk.Keccak256([]byte{1}) + randaoMix = append(randaoMix, sdk.Keccak256([]byte{2})...) + header := testCtx.Ctx.BlockHeader() + header.RandaoMix = randaoMix + testCtx = testutil.TestContext{ + Ctx: sdk.NewContext(testCtx.CMS, header, false, nil, testCtx.Ctx.Logger()), + DB: testCtx.DB, + CMS: testCtx.CMS, + } + + s.ctx = testCtx.Ctx + + ctrl := gomock.NewController(s.T()) + + bankKeeper := types.NewMockBankKeeper(ctrl) + storageKeeper := types.NewMockStorageKeeper(ctrl) + spKeeper := types.NewMockSpKeeper(ctrl) + stakingKeeper := types.NewMockStakingKeeper(ctrl) + paymentKeeper := types.NewMockPaymentKeeper(ctrl) + + s.challengeKeeper = keeper.NewKeeper( + encCfg.Codec, + key, + key, + bankKeeper, + storageKeeper, + spKeeper, + stakingKeeper, + paymentKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + ) + + s.cdc = encCfg.Codec + s.bankKeeper = bankKeeper + s.storageKeeper = storageKeeper + s.spKeeper = spKeeper + s.stakingKeeper = stakingKeeper + s.paymentKeeper = paymentKeeper + + err := s.challengeKeeper.SetParams(s.ctx, types.DefaultParams()) + s.Require().NoError(err) + + queryHelper := baseapp.NewQueryServerTestHelper(testCtx.Ctx, encCfg.InterfaceRegistry) + types.RegisterQueryServer(queryHelper, s.challengeKeeper) + + s.queryClient = types.NewQueryClient(queryHelper) + s.msgServer = keeper.NewMsgServerImpl(*s.challengeKeeper) +} + +func TestTestSuite(t *testing.T) { + suite.Run(t, new(TestSuite)) +} + +func (s *TestSuite) TestBeginBlocker_RemoveExpiredChallenge() { + s.challengeKeeper.SaveChallenge(s.ctx, types.Challenge{ + Id: 100, + ExpiredHeight: 100, + }) + s.challengeKeeper.SaveChallenge(s.ctx, types.Challenge{ + Id: 200, + ExpiredHeight: 300, + }) + + s.ctx = s.ctx.WithBlockHeight(101) + challenge.BeginBlocker(s.ctx, *s.challengeKeeper) + s.Require().False(s.challengeKeeper.ExistsChallenge(s.ctx, 100)) + s.Require().True(s.challengeKeeper.ExistsChallenge(s.ctx, 200)) +} + +func (s *TestSuite) TestBeginBlocker_RemoveSlash() { + s.challengeKeeper.SaveSlash(s.ctx, types.Slash{ + SpId: 100, + ObjectId: sdk.NewUint(100), + Height: 100, + }) + s.challengeKeeper.SaveSlash(s.ctx, types.Slash{ + SpId: 200, + ObjectId: sdk.NewUint(200), + Height: 200, + }) + + params := s.challengeKeeper.GetParams(s.ctx) + params.SlashCoolingOffPeriod = 10 + _ = s.challengeKeeper.SetParams(s.ctx, params) + + s.ctx = s.ctx.WithBlockHeight(101) + challenge.BeginBlocker(s.ctx, *s.challengeKeeper) + s.Require().True(s.challengeKeeper.ExistsSlash(s.ctx, 100, sdk.NewUint(100))) + s.Require().True(s.challengeKeeper.ExistsSlash(s.ctx, 200, sdk.NewUint(200))) + + s.ctx = s.ctx.WithBlockHeight(111) + challenge.BeginBlocker(s.ctx, *s.challengeKeeper) + s.Require().False(s.challengeKeeper.ExistsSlash(s.ctx, 100, sdk.NewUint(100))) + s.Require().True(s.challengeKeeper.ExistsSlash(s.ctx, 200, sdk.NewUint(200))) + + s.ctx = s.ctx.WithBlockHeight(211) + challenge.BeginBlocker(s.ctx, *s.challengeKeeper) + s.Require().False(s.challengeKeeper.ExistsSlash(s.ctx, 100, sdk.NewUint(100))) + s.Require().False(s.challengeKeeper.ExistsSlash(s.ctx, 200, sdk.NewUint(200))) +} + +func (s *TestSuite) TestBeginBlocker_RemoveSpSlashAmount() { + s.challengeKeeper.SetSpSlashAmount(s.ctx, 100, sdk.NewInt(100)) + s.challengeKeeper.SetSpSlashAmount(s.ctx, 200, sdk.NewInt(200)) + + params := s.challengeKeeper.GetParams(s.ctx) + params.SpSlashCountingWindow = 10 + _ = s.challengeKeeper.SetParams(s.ctx, params) + + s.ctx = s.ctx.WithBlockHeight(101) + challenge.BeginBlocker(s.ctx, *s.challengeKeeper) + s.Require().True(s.challengeKeeper.GetSpSlashAmount(s.ctx, 100).Int64() == 100) + s.Require().True(s.challengeKeeper.GetSpSlashAmount(s.ctx, 200).Int64() == 200) + + s.ctx = s.ctx.WithBlockHeight(100) + challenge.BeginBlocker(s.ctx, *s.challengeKeeper) + s.Require().False(s.challengeKeeper.GetSpSlashAmount(s.ctx, 100).Int64() == 100) + s.Require().False(s.challengeKeeper.GetSpSlashAmount(s.ctx, 200).Int64() == 200) +} + +func (s *TestSuite) TestEndBlocker_NoRandomChallenge() { + preChallengeId := s.challengeKeeper.GetChallengeId(s.ctx) + + params := s.challengeKeeper.GetParams(s.ctx) + params.ChallengeCountPerBlock = 0 + _ = s.challengeKeeper.SetParams(s.ctx, params) + + challenge.EndBlocker(s.ctx, *s.challengeKeeper) + afterChallengeId := s.challengeKeeper.GetChallengeId(s.ctx) + s.Require().True(preChallengeId == afterChallengeId) +} + +func (s *TestSuite) TestEndBlocker_ObjectNotExists() { + s.storageKeeper.EXPECT().GetObjectInfoCount(gomock.Any()).Return(sdk.NewUint(0)) + + preChallengeId := s.challengeKeeper.GetChallengeId(s.ctx) + challenge.EndBlocker(s.ctx, *s.challengeKeeper) + afterChallengeId := s.challengeKeeper.GetChallengeId(s.ctx) + s.Require().True(preChallengeId == afterChallengeId) +} + +func (s *TestSuite) TestEndBlocker_SuccessRandomChallenge() { + s.storageKeeper.EXPECT().GetObjectInfoCount(gomock.Any()).Return(sdk.NewUint(100)) + s.storageKeeper.EXPECT().MaxSegmentSize(gomock.Any()).Return(uint64(10000)).AnyTimes() + + existObject := &storagetypes.ObjectInfo{ + Id: math.NewUint(64), + BucketName: "bucketname", + ObjectName: "objectname", + ObjectStatus: storagetypes.OBJECT_STATUS_SEALED, + PayloadSize: 500} + s.storageKeeper.EXPECT().GetObjectInfoById(gomock.Any(), gomock.Eq(existObject.Id)). + Return(existObject, true).AnyTimes() + + existBucket := &storagetypes.BucketInfo{ + BucketName: existObject.BucketName, + Id: math.NewUint(10), + } + s.storageKeeper.EXPECT().GetBucketInfo(gomock.Any(), gomock.Eq(existBucket.BucketName)). + Return(existBucket, true).AnyTimes() + + gvg := &virtualgrouptypes.GlobalVirtualGroup{PrimarySpId: 100, SecondarySpIds: []uint32{ + 1, + }} + s.storageKeeper.EXPECT().GetObjectGVG(gomock.Any(), gomock.Any(), gomock.Any()). + Return(gvg, true).AnyTimes() + + sp := &sptypes.StorageProvider{Id: 1, Status: sptypes.STATUS_IN_SERVICE} + s.spKeeper.EXPECT().GetStorageProvider(gomock.Any(), gomock.Any()). + Return(sp, true).AnyTimes() + + preChallengeId := s.challengeKeeper.GetChallengeId(s.ctx) + challenge.EndBlocker(s.ctx, *s.challengeKeeper) + afterChallengeId := s.challengeKeeper.GetChallengeId(s.ctx) + s.Require().True(preChallengeId == afterChallengeId-1) +} diff --git a/x/challenge/keeper/grpc_query_test.go b/x/challenge/keeper/grpc_query_test.go index a5cc307fd..03b6aa0db 100644 --- a/x/challenge/keeper/grpc_query_test.go +++ b/x/challenge/keeper/grpc_query_test.go @@ -29,6 +29,33 @@ func TestParamsQuery(t *testing.T) { require.Equal(t, &types.QueryParamsResponse{Params: params}, response) } +func TestAttestedChallengeQuery(t *testing.T) { + keeper, ctx := makeKeeper(t) + err := keeper.SetParams(ctx, types.DefaultParams()) + require.NoError(t, err) + c100 := &types.AttestedChallenge{Id: 100, Result: types.CHALLENGE_SUCCEED} + c200 := &types.AttestedChallenge{Id: 200, Result: types.CHALLENGE_FAILED} + keeper.AppendAttestedChallenge(ctx, c100) + keeper.AppendAttestedChallenge(ctx, c200) + + response, err := keeper.AttestedChallenge(ctx, &types.QueryAttestedChallengeRequest{ + ChallengeId: 100, + }) + require.NoError(t, err) + require.Equal(t, &types.QueryAttestedChallengeResponse{Challenge: c100}, response) + + response, err = keeper.AttestedChallenge(ctx, &types.QueryAttestedChallengeRequest{ + ChallengeId: 200, + }) + require.NoError(t, err) + require.Equal(t, &types.QueryAttestedChallengeResponse{Challenge: c200}, response) + + _, err = keeper.AttestedChallenge(ctx, &types.QueryAttestedChallengeRequest{ + ChallengeId: 300, + }) + require.Error(t, err) +} + func TestLatestAttestedChallengesQuery(t *testing.T) { keeper, ctx := makeKeeper(t) err := keeper.SetParams(ctx, types.DefaultParams()) diff --git a/x/challenge/keeper/msg_server_attest_test.go b/x/challenge/keeper/msg_server_attest_test.go index 176c140dc..cfd03edeb 100644 --- a/x/challenge/keeper/msg_server_attest_test.go +++ b/x/challenge/keeper/msg_server_attest_test.go @@ -59,8 +59,8 @@ func (s *TestSuite) TestAttest_Invalid() { name: "unknown challenge", msg: types.MsgAttest{ ChallengeId: 1, - Submitter: sample.AccAddress(), - SpOperatorAddress: sample.AccAddress(), + Submitter: sample.RandAccAddressHex(), + SpOperatorAddress: sample.RandAccAddressHex(), }, err: types.ErrInvalidChallengeId, }, @@ -68,8 +68,8 @@ func (s *TestSuite) TestAttest_Invalid() { name: "not valid submitter", msg: types.MsgAttest{ ChallengeId: 100, - Submitter: sample.AccAddress(), - SpOperatorAddress: sample.AccAddress(), + Submitter: sample.RandAccAddressHex(), + SpOperatorAddress: sample.RandAccAddressHex(), }, err: types.ErrNotChallenger, }, @@ -78,7 +78,7 @@ func (s *TestSuite) TestAttest_Invalid() { msg: types.MsgAttest{ ChallengeId: 100, Submitter: validSubmitter.String(), - SpOperatorAddress: sample.AccAddress(), + SpOperatorAddress: sample.RandAccAddressHex(), ObjectId: math.NewUint(10), VoteValidatorSet: []uint64{}, VoteAggSignature: []byte{}, @@ -90,7 +90,7 @@ func (s *TestSuite) TestAttest_Invalid() { msg: types.MsgAttest{ ChallengeId: 100, Submitter: validSubmitter.String(), - SpOperatorAddress: sample.AccAddress(), + SpOperatorAddress: sample.RandAccAddressHex(), ObjectId: math.NewUint(10), VoteValidatorSet: []uint64{1}, VoteAggSignature: []byte{}, @@ -195,9 +195,13 @@ func (s *TestSuite) TestAttest_Heartbeat() { func (s *TestSuite) TestAttest_Normal() { // prepare challenge - challengeId := uint64(99) + challenge1Id := uint64(99) s.challengeKeeper.SaveChallenge(s.ctx, types.Challenge{ - Id: challengeId, + Id: challenge1Id, + }) + challenge2Id := uint64(100) + s.challengeKeeper.SaveChallenge(s.ctx, types.Challenge{ + Id: challenge2Id, }) validSubmitter := sample.RandAccAddress() @@ -220,14 +224,23 @@ func (s *TestSuite) TestAttest_Normal() { s.storageKeeper.EXPECT().GetBucketInfo(gomock.Any(), gomock.Eq(existBucket.BucketName)). Return(existBucket, true).AnyTimes() - existObject := &storagetypes.ObjectInfo{ + existObject1 := &storagetypes.ObjectInfo{ Id: math.NewUint(10), - ObjectName: "existobject", + ObjectName: "existobject1", BucketName: existBucket.BucketName, ObjectStatus: storagetypes.OBJECT_STATUS_SEALED, PayloadSize: 500} s.storageKeeper.EXPECT().GetObjectInfoById(gomock.Any(), gomock.Eq(math.NewUint(10))). - Return(existObject, true).AnyTimes() + Return(existObject1, true).AnyTimes() + + existObject2 := &storagetypes.ObjectInfo{ + Id: math.NewUint(100), + ObjectName: "existobject2", + BucketName: existBucket.BucketName, + ObjectStatus: storagetypes.OBJECT_STATUS_SEALED, + PayloadSize: 500} + s.storageKeeper.EXPECT().GetObjectInfoById(gomock.Any(), gomock.Eq(math.NewUint(100))). + Return(existObject2, true).AnyTimes() spOperatorAcc := sample.RandAccAddress() sp := &sptypes.StorageProvider{Id: 1, OperatorAddress: spOperatorAcc.String()} @@ -238,30 +251,78 @@ func (s *TestSuite) TestAttest_Normal() { s.spKeeper.EXPECT().GetStorageProviderByOperatorAddr(gomock.Any(), gomock.Any()). Return(sp, true).AnyTimes() s.storageKeeper.EXPECT().MustGetPrimarySPForBucket(gomock.Any(), gomock.Any()).Return(sp).AnyTimes() - attestMsg := &types.MsgAttest{ + + // success attestation + attestMsg1 := &types.MsgAttest{ Submitter: validSubmitter.String(), - ChallengeId: challengeId, + ChallengeId: challenge1Id, ObjectId: math.NewUint(10), SpOperatorAddress: spOperatorAcc.String(), VoteResult: types.CHALLENGE_SUCCEED, ChallengerAddress: "", VoteValidatorSet: []uint64{1}, } - toSign := attestMsg.GetBlsSignBytes(s.ctx.ChainID()) + toSign1 := attestMsg1.GetBlsSignBytes(s.ctx.ChainID()) + voteAggSignature1 := blsKey.Sign(toSign1[:]) + attestMsg1.VoteAggSignature = voteAggSignature1.Marshal() + _, err := s.msgServer.Attest(s.ctx, attestMsg1) + require.NoError(s.T(), err) - voteAggSignature := blsKey.Sign(toSign[:]) - attestMsg.VoteAggSignature = voteAggSignature.Marshal() + attestedChallenges := s.challengeKeeper.GetAttestedChallenges(s.ctx) + attest1Found := false + for _, c := range attestedChallenges { + if c.Id == challenge1Id { + attest1Found = true + } + } + s.Require().True(attest1Found) + s.Require().True(s.challengeKeeper.ExistsSlash(s.ctx, sp.Id, attestMsg1.ObjectId)) - _, err := s.msgServer.Attest(s.ctx, attestMsg) + // success attestation even exceed the max slash amount + params := s.challengeKeeper.GetParams(s.ctx) + params.SpSlashMaxAmount = math.NewInt(1) + _ = s.challengeKeeper.SetParams(s.ctx, params) + + attestMsg2 := &types.MsgAttest{ + Submitter: validSubmitter.String(), + ChallengeId: challenge2Id, + ObjectId: math.NewUint(100), + SpOperatorAddress: spOperatorAcc.String(), + VoteResult: types.CHALLENGE_SUCCEED, + ChallengerAddress: sample.RandAccAddress().String(), + VoteValidatorSet: []uint64{1}, + } + toSign2 := attestMsg2.GetBlsSignBytes(s.ctx.ChainID()) + voteAggSignature2 := blsKey.Sign(toSign2[:]) + attestMsg2.VoteAggSignature = voteAggSignature2.Marshal() + _, err = s.msgServer.Attest(s.ctx, attestMsg2) require.NoError(s.T(), err) - attestedChallenges := s.challengeKeeper.GetAttestedChallenges(s.ctx) - found := false + attestedChallenges = s.challengeKeeper.GetAttestedChallenges(s.ctx) + attest2Found := false for _, c := range attestedChallenges { - if c.Id == challengeId { - found = true + if c.Id == challenge1Id { + attest2Found = true } } - s.Require().True(found) - s.Require().True(s.challengeKeeper.ExistsSlash(s.ctx, sp.Id, attestMsg.ObjectId)) + s.Require().True(attest1Found) + s.Require().True(s.challengeKeeper.ExistsSlash(s.ctx, sp.Id, attestMsg1.ObjectId)) + s.Require().True(attest2Found) + s.Require().True(s.challengeKeeper.ExistsSlash(s.ctx, sp.Id, attestMsg2.ObjectId)) + + // the sp and the object had been slashed + attestMsg3 := &types.MsgAttest{ + Submitter: validSubmitter.String(), + ChallengeId: challenge2Id, + ObjectId: math.NewUint(100), + SpOperatorAddress: spOperatorAcc.String(), + VoteResult: types.CHALLENGE_SUCCEED, + ChallengerAddress: sample.RandAccAddress().String(), + VoteValidatorSet: []uint64{1}, + } + toSign3 := attestMsg3.GetBlsSignBytes(s.ctx.ChainID()) + voteAggSignature3 := blsKey.Sign(toSign3[:]) + attestMsg3.VoteAggSignature = voteAggSignature3.Marshal() + _, err = s.msgServer.Attest(s.ctx, attestMsg3) + require.Error(s.T(), err) } diff --git a/x/challenge/keeper/msg_server_submit_test.go b/x/challenge/keeper/msg_server_submit_test.go index 6b9bebaa9..4b6146337 100644 --- a/x/challenge/keeper/msg_server_submit_test.go +++ b/x/challenge/keeper/msg_server_submit_test.go @@ -19,7 +19,11 @@ func (s *TestSuite) TestSubmit() { existSp := &sptypes.StorageProvider{Status: sptypes.STATUS_IN_SERVICE, Id: 100, OperatorAddress: existSpAddr.String()} s.spKeeper.EXPECT().GetStorageProvider(gomock.Any(), gomock.Eq(existSp.Id)). Return(existSp, true).AnyTimes() - s.storageKeeper.EXPECT().MustGetPrimarySPForBucket(gomock.Any(), gomock.Any()).Return(existSp).AnyTimes() + + jailedSpAddr := sample.RandAccAddress() + jailedSp := &sptypes.StorageProvider{Status: sptypes.STATUS_IN_JAILED, Id: 200, OperatorAddress: jailedSpAddr.String()} + s.spKeeper.EXPECT().GetStorageProvider(gomock.Any(), gomock.Eq(jailedSp.Id)). + Return(jailedSp, true).AnyTimes() existBucketName, existObjectName := "existbucket", "existobject" existObject := &storagetypes.ObjectInfo{ @@ -30,41 +34,88 @@ func (s *TestSuite) TestSubmit() { PayloadSize: 500} s.storageKeeper.EXPECT().GetObjectInfo(gomock.Any(), gomock.Eq(existBucketName), gomock.Eq(existObjectName)). Return(existObject, true).AnyTimes() - s.storageKeeper.EXPECT().GetObjectInfo(gomock.Any(), gomock.Any(), gomock.Any()). - Return(nil, false).AnyTimes() existBucket := &storagetypes.BucketInfo{ BucketName: existBucketName, } s.storageKeeper.EXPECT().GetBucketInfo(gomock.Any(), gomock.Eq(existBucketName)). Return(existBucket, true).AnyTimes() + s.storageKeeper.EXPECT().MustGetPrimarySPForBucket(gomock.Any(), gomock.Eq(existBucket)).Return(existSp).AnyTimes() + + jailedBucketName, jailedObjectName := "jailedbucket", "jailedobject" + jailedObject := &storagetypes.ObjectInfo{ + Id: math.NewUint(10), + BucketName: jailedBucketName, + ObjectName: jailedObjectName, + ObjectStatus: storagetypes.OBJECT_STATUS_SEALED, + PayloadSize: 500} + s.storageKeeper.EXPECT().GetObjectInfo(gomock.Any(), gomock.Eq(jailedBucketName), gomock.Eq(jailedObjectName)). + Return(jailedObject, true).AnyTimes() + + jailedBucket := &storagetypes.BucketInfo{ + BucketName: jailedBucketName, + } + s.storageKeeper.EXPECT().GetBucketInfo(gomock.Any(), gomock.Eq(jailedBucketName)). + Return(jailedBucket, true).AnyTimes() + s.storageKeeper.EXPECT().MustGetPrimarySPForBucket(gomock.Any(), gomock.Eq(jailedBucket)).Return(jailedSp).AnyTimes() + + s.storageKeeper.EXPECT().GetObjectInfo(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, false).AnyTimes() s.storageKeeper.EXPECT().GetBucketInfo(gomock.Any(), gomock.Any()). Return(nil, false).AnyTimes() s.storageKeeper.EXPECT().MaxSegmentSize(gomock.Any()).Return(uint64(10000)).AnyTimes() - gvg := &virtualgrouptypes.GlobalVirtualGroup{PrimarySpId: 100} + gvg := &virtualgrouptypes.GlobalVirtualGroup{PrimarySpId: 100, SecondarySpIds: []uint32{ + 1, + }} s.storageKeeper.EXPECT().GetObjectGVG(gomock.Any(), gomock.Any(), gomock.Any()). Return(gvg, true).AnyTimes() + secondarySpAddr := sample.RandAccAddress() + secondarySp := &sptypes.StorageProvider{Status: sptypes.STATUS_IN_SERVICE, Id: 1, OperatorAddress: secondarySpAddr.String()} + s.spKeeper.EXPECT().GetStorageProvider(gomock.Any(), gomock.Eq(secondarySp.Id)). + Return(secondarySp, true).AnyTimes() + tests := []struct { name string msg types.MsgSubmit err error }{ + { + name: "incorrect sp status", + msg: types.MsgSubmit{ + Challenger: sample.RandAccAddressHex(), + SpOperatorAddress: sample.RandAccAddressHex(), + BucketName: jailedBucketName, + ObjectName: jailedObjectName, + }, + err: types.ErrInvalidSpStatus, + }, { name: "not store on the sp", msg: types.MsgSubmit{ - Challenger: sample.AccAddress(), - SpOperatorAddress: sample.AccAddress(), + Challenger: sample.RandAccAddressHex(), + SpOperatorAddress: sample.RandAccAddressHex(), BucketName: existBucketName, ObjectName: existObjectName, }, err: types.ErrNotStoredOnSp, - }, { + }, + { + name: "unknown bucket", + msg: types.MsgSubmit{ + Challenger: sample.RandAccAddressHex(), + SpOperatorAddress: existSpAddr.String(), + BucketName: "unknownbucket", + ObjectName: "nonexistobject", + }, + err: types.ErrUnknownBucketObject, + }, + { name: "unknown object", msg: types.MsgSubmit{ - Challenger: sample.AccAddress(), + Challenger: sample.RandAccAddressHex(), SpOperatorAddress: existSpAddr.String(), BucketName: existBucketName, ObjectName: "nonexistobject", @@ -74,7 +125,7 @@ func (s *TestSuite) TestSubmit() { { name: "invalid segment index", msg: types.MsgSubmit{ - Challenger: sample.AccAddress(), + Challenger: sample.RandAccAddressHex(), SpOperatorAddress: existSpAddr.String(), BucketName: existBucketName, ObjectName: existObjectName, @@ -85,7 +136,7 @@ func (s *TestSuite) TestSubmit() { { name: "success with specific index", msg: types.MsgSubmit{ - Challenger: sample.AccAddress(), + Challenger: sample.RandAccAddressHex(), SpOperatorAddress: existSpAddr.String(), BucketName: existBucketName, ObjectName: existObjectName, @@ -94,12 +145,21 @@ func (s *TestSuite) TestSubmit() { }, { name: "success with random index", msg: types.MsgSubmit{ - Challenger: sample.AccAddress(), + Challenger: sample.RandAccAddressHex(), SpOperatorAddress: existSpAddr.String(), BucketName: existBucketName, ObjectName: existObjectName, RandomIndex: true, }, + }, { + name: "success with secondary sp", + msg: types.MsgSubmit{ + Challenger: sample.RandAccAddressHex(), + SpOperatorAddress: secondarySpAddr.String(), + BucketName: existBucketName, + ObjectName: existObjectName, + RandomIndex: true, + }, }, } for _, tt := range tests { @@ -114,8 +174,8 @@ func (s *TestSuite) TestSubmit() { } // verify storage - s.Require().Equal(uint64(2), s.challengeKeeper.GetChallengeCountCurrentBlock(s.ctx)) - s.Require().Equal(uint64(2), s.challengeKeeper.GetChallengeId(s.ctx)) + s.Require().Equal(uint64(3), s.challengeKeeper.GetChallengeCountCurrentBlock(s.ctx)) + s.Require().Equal(uint64(3), s.challengeKeeper.GetChallengeId(s.ctx)) // create slash s.challengeKeeper.SaveSlash(s.ctx, types.Slash{ @@ -132,7 +192,7 @@ func (s *TestSuite) TestSubmit() { { name: "failed due to recent slash", msg: types.MsgSubmit{ - Challenger: sample.AccAddress(), + Challenger: sample.RandAccAddressHex(), SpOperatorAddress: existSpAddr.String(), BucketName: existBucketName, ObjectName: existObjectName, diff --git a/x/challenge/keeper/msg_server_update_params_test.go b/x/challenge/keeper/msg_server_update_params_test.go index d20b43b70..d1390cc84 100644 --- a/x/challenge/keeper/msg_server_update_params_test.go +++ b/x/challenge/keeper/msg_server_update_params_test.go @@ -21,7 +21,7 @@ func (s *TestSuite) TestUpdateParams() { { name: "invalid authority", msg: types.MsgUpdateParams{ - Authority: sample.AccAddress(), + Authority: sample.RandAccAddressHex(), }, err: true, }, { diff --git a/x/challenge/keeper/slash_test.go b/x/challenge/keeper/slash_test.go index 036a26338..b9fe882b1 100644 --- a/x/challenge/keeper/slash_test.go +++ b/x/challenge/keeper/slash_test.go @@ -22,7 +22,7 @@ func createSlash(keeper *keeper.Keeper, ctx sdk.Context, n int) []types.Slash { return items } -func TestRecentSlashRemove(t *testing.T) { +func TestRemoveRecentSlash(t *testing.T) { keeper, ctx := makeKeeper(t) items := createSlash(keeper, ctx, 10) for _, item := range items { @@ -31,3 +31,12 @@ func TestRecentSlashRemove(t *testing.T) { require.False(t, found) } } + +func TestRemoveSpSlashAmount(t *testing.T) { + keeper, ctx := makeKeeper(t) + keeper.SetSpSlashAmount(ctx, 1, sdk.NewInt(100)) + keeper.SetSpSlashAmount(ctx, 2, sdk.NewInt(200)) + keeper.ClearSpSlashAmount(ctx) + require.True(t, keeper.GetSpSlashAmount(ctx, 1).Int64() == 0) + require.True(t, keeper.GetSpSlashAmount(ctx, 2).Int64() == 0) +} diff --git a/x/challenge/module_simulation.go b/x/challenge/module_simulation.go index 4c1c2824c..971952cfd 100644 --- a/x/challenge/module_simulation.go +++ b/x/challenge/module_simulation.go @@ -16,7 +16,7 @@ import ( // avoid unused import issue var ( - _ = sample.AccAddress + _ = sample.RandAccAddressHex _ = challengesimulation.FindAccount _ = simulation.MsgEntryKind _ = baseapp.Paramspace diff --git a/x/challenge/types/message_attest_test.go b/x/challenge/types/message_attest_test.go index 04a2313e4..3b5cbcf95 100644 --- a/x/challenge/types/message_attest_test.go +++ b/x/challenge/types/message_attest_test.go @@ -25,16 +25,16 @@ func TestMsgAttest_ValidateBasic(t *testing.T) { }, { name: "invalid vote result", msg: MsgAttest{ - Submitter: sample.AccAddress(), - SpOperatorAddress: sample.AccAddress(), + Submitter: sample.RandAccAddressHex(), + SpOperatorAddress: sample.RandAccAddressHex(), VoteResult: 100, }, err: ErrInvalidVoteResult, }, { name: "invalid vote result", msg: MsgAttest{ - Submitter: sample.AccAddress(), - SpOperatorAddress: sample.AccAddress(), + Submitter: sample.RandAccAddressHex(), + SpOperatorAddress: sample.RandAccAddressHex(), VoteResult: CHALLENGE_SUCCEED, VoteValidatorSet: make([]uint64, 0), }, @@ -42,8 +42,8 @@ func TestMsgAttest_ValidateBasic(t *testing.T) { }, { name: "invalid vote aggregated signature", msg: MsgAttest{ - Submitter: sample.AccAddress(), - SpOperatorAddress: sample.AccAddress(), + Submitter: sample.RandAccAddressHex(), + SpOperatorAddress: sample.RandAccAddressHex(), VoteResult: CHALLENGE_SUCCEED, VoteValidatorSet: []uint64{1}, VoteAggSignature: []byte{1, 2, 3}, @@ -52,8 +52,8 @@ func TestMsgAttest_ValidateBasic(t *testing.T) { }, { name: "valid message", msg: MsgAttest{ - Submitter: sample.AccAddress(), - SpOperatorAddress: sample.AccAddress(), + Submitter: sample.RandAccAddressHex(), + SpOperatorAddress: sample.RandAccAddressHex(), VoteResult: CHALLENGE_SUCCEED, VoteValidatorSet: []uint64{1}, VoteAggSignature: sig[:], diff --git a/x/challenge/types/message_submit_test.go b/x/challenge/types/message_submit_test.go index 7cd18c44d..c3ee86256 100644 --- a/x/challenge/types/message_submit_test.go +++ b/x/challenge/types/message_submit_test.go @@ -25,16 +25,16 @@ func TestMsgSubmit_ValidateBasic(t *testing.T) { }, { name: "invalid bucket name", msg: MsgSubmit{ - Challenger: sample.AccAddress(), - SpOperatorAddress: sample.AccAddress(), + Challenger: sample.RandAccAddressHex(), + SpOperatorAddress: sample.RandAccAddressHex(), BucketName: "1", }, err: gnfderrors.ErrInvalidBucketName, }, { name: "invalid object name", msg: MsgSubmit{ - Challenger: sample.AccAddress(), - SpOperatorAddress: sample.AccAddress(), + Challenger: sample.RandAccAddressHex(), + SpOperatorAddress: sample.RandAccAddressHex(), BucketName: "bucket", ObjectName: "", }, @@ -42,8 +42,8 @@ func TestMsgSubmit_ValidateBasic(t *testing.T) { }, { name: "valid message with random index", msg: MsgSubmit{ - Challenger: sample.AccAddress(), - SpOperatorAddress: sample.AccAddress(), + Challenger: sample.RandAccAddressHex(), + SpOperatorAddress: sample.RandAccAddressHex(), BucketName: "bucket", ObjectName: "object", RandomIndex: true, @@ -52,8 +52,8 @@ func TestMsgSubmit_ValidateBasic(t *testing.T) { }, { name: "valid message with specific index", msg: MsgSubmit{ - Challenger: sample.AccAddress(), - SpOperatorAddress: sample.AccAddress(), + Challenger: sample.RandAccAddressHex(), + SpOperatorAddress: sample.RandAccAddressHex(), BucketName: "bucket", ObjectName: "object", RandomIndex: false, diff --git a/x/challenge/types/message_update_params_test.go b/x/challenge/types/message_update_params_test.go index 8d6e34eab..7f69da570 100644 --- a/x/challenge/types/message_update_params_test.go +++ b/x/challenge/types/message_update_params_test.go @@ -29,14 +29,14 @@ func TestMsgUpdateParams_ValidateBasic(t *testing.T) { }, { name: "invalid params", msg: MsgUpdateParams{ - Authority: sample.AccAddress(), + Authority: sample.RandAccAddressHex(), Params: wrongParams, }, err: ErrInvalidParams, }, { name: "valid authority and params", msg: MsgUpdateParams{ - Authority: sample.AccAddress(), + Authority: sample.RandAccAddressHex(), Params: DefaultParams(), }, }, diff --git a/x/payment/client/cli/query.go b/x/payment/client/cli/query.go index bfffe9292..7c130eb7b 100644 --- a/x/payment/client/cli/query.go +++ b/x/payment/client/cli/query.go @@ -10,7 +10,7 @@ import ( ) // GetQueryCmd returns the cli query commands for this module -func GetQueryCmd(queryRoute string) *cobra.Command { +func GetQueryCmd() *cobra.Command { // Group payment queries under a subcommand cmd := &cobra.Command{ Use: types.ModuleName, diff --git a/x/payment/keeper/auto_resume_record_test.go b/x/payment/keeper/auto_resume_record_test.go new file mode 100644 index 000000000..eb175835d --- /dev/null +++ b/x/payment/keeper/auto_resume_record_test.go @@ -0,0 +1,52 @@ +package keeper_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/bnb-chain/greenfield/testutil/sample" + "github.com/bnb-chain/greenfield/x/payment/types" +) + +func TestAutoResumeRecord(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + + addr1 := sample.RandAccAddress() + record1 := &types.AutoResumeRecord{ + Addr: addr1.String(), + Timestamp: 100, + } + + addr2 := sample.RandAccAddress() + record2 := &types.AutoResumeRecord{ + Addr: addr2.String(), + Timestamp: 200, + } + + // set + keeper.SetAutoResumeRecord(ctx, record1) + keeper.SetAutoResumeRecord(ctx, record2) + + // exits + // before the timestamp + exist := keeper.ExistsAutoResumeRecord(ctx, 90, addr1) + require.True(t, !exist) + exist = keeper.ExistsAutoResumeRecord(ctx, 101, addr1) + require.True(t, exist) + + // at any time + exist = keeper.ExistsAutoResumeRecord(ctx, 0, addr1) + require.True(t, exist) + exist = keeper.ExistsAutoResumeRecord(ctx, 0, addr2) + require.True(t, exist) + + // remove + keeper.RemoveAutoResumeRecord(ctx, record1.Timestamp, addr1) + keeper.RemoveAutoResumeRecord(ctx, record2.Timestamp, addr2) + + exist = keeper.ExistsAutoResumeRecord(ctx, 0, addr1) + require.True(t, !exist) + exist = keeper.ExistsAutoResumeRecord(ctx, 0, addr2) + require.True(t, !exist) +} diff --git a/x/payment/keeper/auto_settle_record.go b/x/payment/keeper/auto_settle_record.go index 7c5c5f7c9..afd7ba074 100644 --- a/x/payment/keeper/auto_settle_record.go +++ b/x/payment/keeper/auto_settle_record.go @@ -18,28 +18,6 @@ func (k Keeper) SetAutoSettleRecord(ctx sdk.Context, autoSettleRecord *types.Aut ), b) } -// GetAutoSettleRecord returns a autoSettleRecord from its index -func (k Keeper) GetAutoSettleRecord( - ctx sdk.Context, - timestamp int64, - addr sdk.AccAddress, -) (*types.AutoSettleRecord, bool) { - store := prefix.NewStore(ctx.KVStore(k.storeKey), types.AutoSettleRecordKeyPrefix) - - b := store.Get(types.AutoSettleRecordKey( - timestamp, - addr, - )) - if b == nil { - return nil, false - } - - return &types.AutoSettleRecord{ - Timestamp: timestamp, - Addr: addr.String(), - }, true -} - // RemoveAutoSettleRecord removes a autoSettleRecord from the store func (k Keeper) RemoveAutoSettleRecord( ctx sdk.Context, diff --git a/x/payment/keeper/auto_settle_record_test.go b/x/payment/keeper/auto_settle_record_test.go new file mode 100644 index 000000000..1dfa6e663 --- /dev/null +++ b/x/payment/keeper/auto_settle_record_test.go @@ -0,0 +1,42 @@ +package keeper_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/bnb-chain/greenfield/testutil/sample" + "github.com/bnb-chain/greenfield/x/payment/types" +) + +func TestAutoSettleRecord(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + + addr1 := sample.RandAccAddress() + record1 := &types.AutoSettleRecord{ + Addr: addr1.String(), + Timestamp: 100, + } + + addr2 := sample.RandAccAddress() + record2 := &types.AutoSettleRecord{ + Addr: addr2.String(), + Timestamp: 200, + } + + // set + keeper.SetAutoSettleRecord(ctx, record1) + keeper.SetAutoSettleRecord(ctx, record2) + + // update to new time + keeper.UpdateAutoSettleRecord(ctx, addr1, record1.Timestamp, 110) + + // update to remove + keeper.UpdateAutoSettleRecord(ctx, addr2, record2.Timestamp, 0) + + // get all + records := keeper.GetAllAutoSettleRecord(ctx) + require.True(t, len(records) == 1) + require.True(t, records[0].Addr == addr1.String()) + require.True(t, records[0].Timestamp == 110) +} diff --git a/x/payment/keeper/grpc_query_test.go b/x/payment/keeper/grpc_query_test.go new file mode 100644 index 000000000..b7ce3a41f --- /dev/null +++ b/x/payment/keeper/grpc_query_test.go @@ -0,0 +1,306 @@ +package keeper_test + +import ( + "testing" + "time" + + sdkmath "cosmossdk.io/math" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + + "github.com/bnb-chain/greenfield/testutil/sample" + "github.com/bnb-chain/greenfield/x/payment/types" +) + +func TestParamsQuery(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + params := types.DefaultParams() + err := keeper.SetParams(ctx, params) + require.NoError(t, err) + + response, err := keeper.Params(ctx, &types.QueryParamsRequest{}) + require.NoError(t, err) + require.Equal(t, &types.QueryParamsResponse{Params: params}, response) +} + +func TestParamsByTimestampQuery(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + params := types.DefaultParams() + + before := time.Now() + ctx = ctx.WithBlockTime(before) + err := keeper.SetParams(ctx, params) + require.NoError(t, err) + + after := time.Unix(before.Unix()+10, 0) + ctx = ctx.WithBlockTime(after) + newReserveTime := uint64(1000000000) + params.VersionedParams.ReserveTime = newReserveTime + err = keeper.SetParams(ctx, params) + require.NoError(t, err) + + response, err := keeper.ParamsByTimestamp(ctx, &types.QueryParamsByTimestampRequest{ + Timestamp: before.Unix(), + }) + require.NoError(t, err) + require.True(t, newReserveTime != response.Params.VersionedParams.ReserveTime) + + response, err = keeper.ParamsByTimestamp(ctx, &types.QueryParamsByTimestampRequest{ + Timestamp: after.Unix(), + }) + require.NoError(t, err) + require.True(t, newReserveTime == response.Params.VersionedParams.ReserveTime) +} + +func TestAutoSettleRecordQuery(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + params := types.DefaultParams() + err := keeper.SetParams(ctx, params) + require.NoError(t, err) + + record := types.AutoSettleRecord{ + Timestamp: 123, + Addr: sample.RandAccAddress().String(), + } + keeper.SetAutoSettleRecord(ctx, &record) + + response, err := keeper.AutoSettleRecordAll(ctx, &types.QueryAllAutoSettleRecordRequest{}) + require.NoError(t, err) + require.Equal(t, record, response.AutoSettleRecord[0]) +} + +func TestDynamicBalanceQuery(t *testing.T) { + keeper, ctx, deepKeepers := makePaymentKeeper(t) + params := types.DefaultParams() + err := keeper.SetParams(ctx, params) + require.NoError(t, err) + + deepKeepers.AccountKeeper.EXPECT().HasAccount(gomock.Any(), gomock.Any()). + Return(true).AnyTimes() + bankBalance := sdk.NewCoin("BNB", sdkmath.NewInt(1000)) + + deepKeepers.BankKeeper.EXPECT().GetBalance(gomock.Any(), gomock.Any(), gomock.Any()). + Return(bankBalance).AnyTimes() + + record := types.NewStreamRecord(sample.RandAccAddress(), ctx.BlockTime().Unix()) + record.StaticBalance = sdkmath.NewInt(100) + keeper.SetStreamRecord(ctx, record) + + response, err := keeper.DynamicBalance(ctx, &types.QueryDynamicBalanceRequest{Account: record.Account}) + require.NoError(t, err) + require.Equal(t, record.StaticBalance.Add(bankBalance.Amount), response.AvailableBalance) + require.Equal(t, bankBalance.Amount, response.BankBalance) +} + +func TestPaymentAccountAllQuery(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + params := types.DefaultParams() + err := keeper.SetParams(ctx, params) + require.NoError(t, err) + + owner1 := sample.RandAccAddress() + record1 := types.PaymentAccount{ + Owner: owner1.String(), + Addr: sample.RandAccAddress().String(), + } + keeper.SetPaymentAccount(ctx, &record1) + + owner2 := sample.RandAccAddress() + record2 := types.PaymentAccount{ + Owner: owner2.String(), + Addr: sample.RandAccAddress().String(), + } + keeper.SetPaymentAccount(ctx, &record2) + + response, err := keeper.PaymentAccountAll(ctx, &types.QueryAllPaymentAccountRequest{}) + require.NoError(t, err) + require.Equal(t, 2, len(response.PaymentAccount)) +} + +func TestPaymentAccountQuery(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + params := types.DefaultParams() + err := keeper.SetParams(ctx, params) + require.NoError(t, err) + + owner1 := sample.RandAccAddress() + addr1 := sample.RandAccAddress().String() + record1 := types.PaymentAccount{ + Owner: owner1.String(), + Addr: addr1, + } + keeper.SetPaymentAccount(ctx, &record1) + + owner2 := sample.RandAccAddress() + addr2 := sample.RandAccAddress().String() + record2 := types.PaymentAccount{ + Owner: owner2.String(), + Addr: addr2, + } + keeper.SetPaymentAccount(ctx, &record2) + + response, err := keeper.PaymentAccount(ctx, &types.QueryGetPaymentAccountRequest{ + Addr: addr1, + }) + require.NoError(t, err) + require.Equal(t, owner1.String(), response.PaymentAccount.Owner) + + response, err = keeper.PaymentAccount(ctx, &types.QueryGetPaymentAccountRequest{ + Addr: addr2, + }) + require.NoError(t, err) + require.Equal(t, owner2.String(), response.PaymentAccount.Owner) +} + +func TestPaymentAccountCountAllQuery(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + params := types.DefaultParams() + err := keeper.SetParams(ctx, params) + require.NoError(t, err) + + owner1 := sample.RandAccAddress() + record1 := types.PaymentAccountCount{ + Owner: owner1.String(), + Count: 10, + } + keeper.SetPaymentAccountCount(ctx, &record1) + + owner2 := sample.RandAccAddress() + record2 := types.PaymentAccountCount{ + Owner: owner2.String(), + Count: 2, + } + keeper.SetPaymentAccountCount(ctx, &record2) + + response, err := keeper.PaymentAccountCountAll(ctx, &types.QueryAllPaymentAccountCountRequest{}) + require.NoError(t, err) + require.Equal(t, 2, len(response.PaymentAccountCount)) +} + +func TestPaymentAccountCountQuery(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + params := types.DefaultParams() + err := keeper.SetParams(ctx, params) + require.NoError(t, err) + + owner1 := sample.RandAccAddress() + record1 := types.PaymentAccountCount{ + Owner: owner1.String(), + Count: 10, + } + keeper.SetPaymentAccountCount(ctx, &record1) + + owner2 := sample.RandAccAddress() + record2 := types.PaymentAccountCount{ + Owner: owner2.String(), + Count: 2, + } + keeper.SetPaymentAccountCount(ctx, &record2) + + response, err := keeper.PaymentAccountCount(ctx, &types.QueryGetPaymentAccountCountRequest{ + Owner: owner1.String(), + }) + require.NoError(t, err) + require.Equal(t, record1.Count, response.PaymentAccountCount.Count) + + response, err = keeper.PaymentAccountCount(ctx, &types.QueryGetPaymentAccountCountRequest{ + Owner: owner2.String(), + }) + require.NoError(t, err) + require.Equal(t, record2.Count, response.PaymentAccountCount.Count) +} + +func TestPaymentAccountsByOwnerQuery(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + params := types.DefaultParams() + err := keeper.SetParams(ctx, params) + require.NoError(t, err) + + owner1 := sample.RandAccAddress() + record1 := types.PaymentAccountCount{ + Owner: owner1.String(), + Count: 2, + } + keeper.SetPaymentAccountCount(ctx, &record1) + + response, err := keeper.GetPaymentAccountsByOwner(ctx, &types.QueryGetPaymentAccountsByOwnerRequest{ + Owner: owner1.String(), + }) + require.NoError(t, err) + require.Equal(t, int(record1.Count), len(response.PaymentAccounts)) +} + +func TestStreamRecordAllQuery(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + params := types.DefaultParams() + err := keeper.SetParams(ctx, params) + require.NoError(t, err) + + owner1 := sample.RandAccAddress() + record1 := types.NewStreamRecord(owner1, ctx.BlockTime().Unix()) + keeper.SetStreamRecord(ctx, record1) + + owner2 := sample.RandAccAddress() + record2 := types.NewStreamRecord(owner2, ctx.BlockTime().Unix()) + keeper.SetStreamRecord(ctx, record2) + + response, err := keeper.StreamRecordAll(ctx, &types.QueryAllStreamRecordRequest{}) + require.NoError(t, err) + require.Equal(t, 2, len(response.StreamRecord)) +} + +func TestStreamRecordQuery(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + params := types.DefaultParams() + err := keeper.SetParams(ctx, params) + require.NoError(t, err) + + owner1 := sample.RandAccAddress() + record1 := types.NewStreamRecord(owner1, ctx.BlockTime().Unix()) + keeper.SetStreamRecord(ctx, record1) + + owner2 := sample.RandAccAddress() + record2 := types.NewStreamRecord(owner2, ctx.BlockTime().Unix()) + keeper.SetStreamRecord(ctx, record2) + + response, err := keeper.StreamRecord(ctx, &types.QueryGetStreamRecordRequest{ + Account: owner1.String(), + }) + require.NoError(t, err) + require.Equal(t, owner1.String(), response.StreamRecord.Account) + + response, err = keeper.StreamRecord(ctx, &types.QueryGetStreamRecordRequest{ + Account: owner2.String(), + }) + require.NoError(t, err) + require.Equal(t, owner2.String(), response.StreamRecord.Account) +} + +func TestOutFlowQuery(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + params := types.DefaultParams() + err := keeper.SetParams(ctx, params) + require.NoError(t, err) + + owner := sample.RandAccAddress() + record1 := types.OutFlow{ + ToAddress: sample.RandAccAddress().String(), + Rate: sdkmath.Int{}, + Status: types.OUT_FLOW_STATUS_FROZEN, + } + keeper.SetOutFlow(ctx, owner, &record1) + + record2 := types.OutFlow{ + ToAddress: sample.RandAccAddress().String(), + Rate: sdkmath.Int{}, + Status: types.OUT_FLOW_STATUS_ACTIVE, + } + keeper.SetOutFlow(ctx, owner, &record2) + + response, err := keeper.OutFlows(ctx, &types.QueryOutFlowsRequest{ + Account: owner.String(), + }) + require.NoError(t, err) + require.Equal(t, 2, len(response.OutFlows)) +} diff --git a/x/payment/keeper/msg_server_create_payment_account_test.go b/x/payment/keeper/msg_server_create_payment_account_test.go new file mode 100644 index 000000000..5abf93b27 --- /dev/null +++ b/x/payment/keeper/msg_server_create_payment_account_test.go @@ -0,0 +1,35 @@ +package keeper_test + +import ( + "github.com/bnb-chain/greenfield/testutil/sample" + "github.com/bnb-chain/greenfield/x/payment/types" +) + +func (s *TestSuite) TestCreatePaymentAccount() { + creator := sample.RandAccAddress() + + // create first one + msg := types.NewMsgCreatePaymentAccount(creator.String()) + _, err := s.msgServer.CreatePaymentAccount(s.ctx, msg) + s.Require().NoError(err) + + record, _ := s.paymentKeeper.GetPaymentAccountCount(s.ctx, creator) + s.Require().True(record.Count == 1) + + // create another one + msg = types.NewMsgCreatePaymentAccount(creator.String()) + _, err = s.msgServer.CreatePaymentAccount(s.ctx, msg) + s.Require().NoError(err) + + record, _ = s.paymentKeeper.GetPaymentAccountCount(s.ctx, creator) + s.Require().True(record.Count == 2) + + // limit the number of payment account + params := s.paymentKeeper.GetParams(s.ctx) + params.PaymentAccountCountLimit = 2 + _ = s.paymentKeeper.SetParams(s.ctx, params) + + msg = types.NewMsgCreatePaymentAccount(creator.String()) + _, err = s.msgServer.CreatePaymentAccount(s.ctx, msg) + s.Require().Error(err) +} diff --git a/x/payment/keeper/msg_server_deposit_test.go b/x/payment/keeper/msg_server_deposit_test.go new file mode 100644 index 000000000..72a42a131 --- /dev/null +++ b/x/payment/keeper/msg_server_deposit_test.go @@ -0,0 +1,72 @@ +package keeper_test + +import ( + sdkmath "cosmossdk.io/math" + "github.com/golang/mock/gomock" + + "github.com/bnb-chain/greenfield/testutil/sample" + "github.com/bnb-chain/greenfield/x/payment/types" +) + +func (s *TestSuite) TestDeposit_ToBankAccount() { + s.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil).AnyTimes() + s.accountKeeper.EXPECT().HasAccount(gomock.Any(), gomock.Any()). + Return(true).AnyTimes() + + // deposit to self + owner := sample.RandAccAddress() + msg := types.NewMsgDeposit(owner.String(), owner.String(), sdkmath.NewInt(1000)) + _, err := s.msgServer.Deposit(s.ctx, msg) + s.Require().NoError(err) + record, _ := s.paymentKeeper.GetStreamRecord(s.ctx, owner) + s.Require().True(record.StaticBalance.Int64() == msg.Amount.Int64()) + + // deposit to other account + to := sample.RandAccAddress() + msg = types.NewMsgDeposit(owner.String(), to.String(), sdkmath.NewInt(1000)) + _, err = s.msgServer.Deposit(s.ctx, msg) + s.Require().NoError(err) + record, _ = s.paymentKeeper.GetStreamRecord(s.ctx, to) + s.Require().True(record.StaticBalance.Int64() == msg.Amount.Int64()) +} + +func (s *TestSuite) TestDeposit_ToActiveStreamRecord() { + s.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil).AnyTimes() + s.accountKeeper.EXPECT().HasAccount(gomock.Any(), gomock.Any()). + Return(true).AnyTimes() + + owner := sample.RandAccAddress() + paymentAddr := sample.RandAccAddress() + record := types.NewStreamRecord(paymentAddr, s.ctx.BlockTime().Unix()) + s.paymentKeeper.SetStreamRecord(s.ctx, record) + + // deposit to active stream record + msg := types.NewMsgDeposit(owner.String(), paymentAddr.String(), sdkmath.NewInt(1000)) + _, err := s.msgServer.Deposit(s.ctx, msg) + s.Require().NoError(err) + recordAfter, _ := s.paymentKeeper.GetStreamRecord(s.ctx, paymentAddr) + s.Require().True(recordAfter.StaticBalance.Int64() == msg.Amount.Int64()) +} + +func (s *TestSuite) TestDeposit_ToFrozenStreamRecord() { + s.bankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil).AnyTimes() + s.accountKeeper.EXPECT().HasAccount(gomock.Any(), gomock.Any()). + Return(true).AnyTimes() + + owner := sample.RandAccAddress() + paymentAddr := sample.RandAccAddress() + record := types.NewStreamRecord(paymentAddr, s.ctx.BlockTime().Unix()) + record.Status = types.STREAM_ACCOUNT_STATUS_FROZEN + record.FrozenNetflowRate = sdkmath.NewInt(-10) + s.paymentKeeper.SetStreamRecord(s.ctx, record) + + // deposit to frozen stream record + msg := types.NewMsgDeposit(owner.String(), paymentAddr.String(), sdkmath.NewInt(1000)) + _, err := s.msgServer.Deposit(s.ctx, msg) + s.Require().NoError(err) + recordAfter, _ := s.paymentKeeper.GetStreamRecord(s.ctx, paymentAddr) + s.Require().True(recordAfter.StaticBalance.Int64() == msg.Amount.Int64()) +} diff --git a/x/payment/keeper/msg_server_disable_refund_test.go b/x/payment/keeper/msg_server_disable_refund_test.go new file mode 100644 index 000000000..172e93aca --- /dev/null +++ b/x/payment/keeper/msg_server_disable_refund_test.go @@ -0,0 +1,39 @@ +package keeper_test + +import ( + "github.com/bnb-chain/greenfield/testutil/sample" + "github.com/bnb-chain/greenfield/x/payment/types" +) + +func (s *TestSuite) TestDisableRefund() { + // payment account does not exist + creator1 := sample.RandAccAddress() + msg := types.NewMsgDisableRefund(creator1.String(), sample.RandAccAddress().String()) + _, err := s.msgServer.DisableRefund(s.ctx, msg) + s.Require().Error(err) + + // the message is not from the owner + creator2 := sample.RandAccAddress() + createAccountMsg := types.NewMsgCreatePaymentAccount(creator2.String()) + _, err = s.msgServer.CreatePaymentAccount(s.ctx, createAccountMsg) + s.Require().NoError(err) + paymentAccountAddr := s.paymentKeeper.DerivePaymentAccountAddress(creator2, 0) + record, _ := s.paymentKeeper.GetPaymentAccount(s.ctx, paymentAccountAddr) + s.Require().True(record.Owner == creator2.String()) + + msg = types.NewMsgDisableRefund(creator1.String(), paymentAccountAddr.String()) + _, err = s.msgServer.DisableRefund(s.ctx, msg) + s.Require().Error(err) + + // disable refund success + msg = types.NewMsgDisableRefund(creator2.String(), paymentAccountAddr.String()) + _, err = s.msgServer.DisableRefund(s.ctx, msg) + s.Require().NoError(err) + record, _ = s.paymentKeeper.GetPaymentAccount(s.ctx, paymentAccountAddr) + s.Require().True(record.Refundable == false) + + // cannot disable it again + msg = types.NewMsgDisableRefund(creator2.String(), paymentAccountAddr.String()) + _, err = s.msgServer.DisableRefund(s.ctx, msg) + s.Require().Error(err) +} diff --git a/x/payment/keeper/msg_server_test.go b/x/payment/keeper/msg_server_test.go new file mode 100644 index 000000000..c72134039 --- /dev/null +++ b/x/payment/keeper/msg_server_test.go @@ -0,0 +1,115 @@ +package keeper_test + +import ( + "testing" + + "github.com/cosmos/cosmos-sdk/baseapp" + "github.com/cosmos/cosmos-sdk/codec" + storetypes "github.com/cosmos/cosmos-sdk/store/types" + "github.com/cosmos/cosmos-sdk/testutil" + sdk "github.com/cosmos/cosmos-sdk/types" + moduletestutil "github.com/cosmos/cosmos-sdk/types/module/testutil" + authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" + govtypes "github.com/cosmos/cosmos-sdk/x/gov/types" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + + "github.com/bnb-chain/greenfield/testutil/sample" + "github.com/bnb-chain/greenfield/x/challenge" + "github.com/bnb-chain/greenfield/x/payment/keeper" + "github.com/bnb-chain/greenfield/x/payment/types" +) + +type TestSuite struct { + suite.Suite + + cdc codec.Codec + paymentKeeper *keeper.Keeper + + bankKeeper *types.MockBankKeeper + accountKeeper *types.MockAccountKeeper + spKeeper *types.MockSpKeeper + + ctx sdk.Context + queryClient types.QueryClient + msgServer types.MsgServer +} + +func (s *TestSuite) SetupTest() { + encCfg := moduletestutil.MakeTestEncodingConfig(challenge.AppModuleBasic{}) + key := storetypes.NewKVStoreKey(types.StoreKey) + testCtx := testutil.DefaultContextWithDB(s.T(), key, storetypes.NewTransientStoreKey("transient_test")) + s.ctx = testCtx.Ctx + + ctrl := gomock.NewController(s.T()) + + bankKeeper := types.NewMockBankKeeper(ctrl) + accountKeeper := types.NewMockAccountKeeper(ctrl) + spKeeper := types.NewMockSpKeeper(ctrl) + + s.paymentKeeper = keeper.NewKeeper( + encCfg.Codec, + key, + bankKeeper, + accountKeeper, + spKeeper, + authtypes.NewModuleAddress(govtypes.ModuleName).String(), + ) + + s.cdc = encCfg.Codec + s.bankKeeper = bankKeeper + s.accountKeeper = accountKeeper + s.spKeeper = spKeeper + + err := s.paymentKeeper.SetParams(s.ctx, types.DefaultParams()) + s.Require().NoError(err) + + queryHelper := baseapp.NewQueryServerTestHelper(testCtx.Ctx, encCfg.InterfaceRegistry) + types.RegisterQueryServer(queryHelper, s.paymentKeeper) + + s.queryClient = types.NewQueryClient(queryHelper) + s.msgServer = keeper.NewMsgServerImpl(*s.paymentKeeper) +} + +func TestTestSuite(t *testing.T) { + suite.Run(t, new(TestSuite)) +} + +func (s *TestSuite) TestUpdateParams() { + params := types.DefaultParams() + params.MaxAutoResumeFlowCount = 5 + + tests := []struct { + name string + msg types.MsgUpdateParams + err bool + }{ + { + name: "invalid authority", + msg: types.MsgUpdateParams{ + Authority: sample.AccAddress(), + }, + err: true, + }, { + name: "success", + msg: types.MsgUpdateParams{ + Authority: s.paymentKeeper.GetAuthority(), + Params: params, + }, + }, + } + for _, tt := range tests { + s.T().Run(tt.name, func(t *testing.T) { + _, err := s.msgServer.UpdateParams(s.ctx, &tt.msg) + if tt.err { + require.Error(t, err) + return + } + require.NoError(t, err) + }) + } + + // verify storage + s.Require().Equal(params, s.paymentKeeper.GetParams(s.ctx)) +} diff --git a/x/payment/keeper/msg_server_withdraw_test.go b/x/payment/keeper/msg_server_withdraw_test.go new file mode 100644 index 000000000..8f836ee1f --- /dev/null +++ b/x/payment/keeper/msg_server_withdraw_test.go @@ -0,0 +1,86 @@ +package keeper_test + +import ( + sdkmath "cosmossdk.io/math" + "github.com/golang/mock/gomock" + + "github.com/bnb-chain/greenfield/testutil/sample" + "github.com/bnb-chain/greenfield/x/payment/types" +) + +func (s *TestSuite) TestWithdraw_Fail() { + creator1 := sample.RandAccAddress() + paymentAddr1 := sample.RandAccAddress() + + // stream record not found + msg := types.NewMsgWithdraw(creator1.String(), sample.RandAccAddress().String(), sdkmath.NewInt(100)) + _, err := s.msgServer.Withdraw(s.ctx, msg) + s.Require().Error(err) + + // stream record is frozen + record1 := types.NewStreamRecord(paymentAddr1, s.ctx.BlockTime().Unix()) + record1.Status = types.STREAM_ACCOUNT_STATUS_FROZEN + s.paymentKeeper.SetStreamRecord(s.ctx, record1) + + msg = types.NewMsgWithdraw(creator1.String(), paymentAddr1.String(), sdkmath.NewInt(100)) + _, err = s.msgServer.Withdraw(s.ctx, msg) + s.Require().Error(err) + + record1.Status = types.STREAM_ACCOUNT_STATUS_ACTIVE + s.paymentKeeper.SetStreamRecord(s.ctx, record1) + + // payment account does not exist + msg = types.NewMsgWithdraw(creator1.String(), paymentAddr1.String(), sdkmath.NewInt(100)) + _, err = s.msgServer.Withdraw(s.ctx, msg) + s.Require().Error(err) + + // the message is not from the owner + creator2 := sample.RandAccAddress() + createAccountMsg := types.NewMsgCreatePaymentAccount(creator2.String()) + _, err = s.msgServer.CreatePaymentAccount(s.ctx, createAccountMsg) + s.Require().NoError(err) + paymentAddr2 := s.paymentKeeper.DerivePaymentAccountAddress(creator2, 0) + paymentAccountRecord, _ := s.paymentKeeper.GetPaymentAccount(s.ctx, paymentAddr2) + s.Require().True(paymentAccountRecord.Owner == creator2.String()) + + record2 := types.NewStreamRecord(paymentAddr2, s.ctx.BlockTime().Unix()) + s.paymentKeeper.SetStreamRecord(s.ctx, record2) + + msg = types.NewMsgWithdraw(creator1.String(), paymentAddr2.String(), sdkmath.NewInt(100)) + _, err = s.msgServer.Withdraw(s.ctx, msg) + s.Require().Error(err) + + // cannot withdraw after disable refund + disableRefundMsg := types.NewMsgDisableRefund(creator2.String(), paymentAddr2.String()) + _, err = s.msgServer.DisableRefund(s.ctx, disableRefundMsg) + s.Require().NoError(err) + paymentAccountRecord, _ = s.paymentKeeper.GetPaymentAccount(s.ctx, paymentAddr2) + s.Require().True(paymentAccountRecord.Refundable == false) + + msg = types.NewMsgWithdraw(creator2.String(), paymentAddr2.String(), sdkmath.NewInt(100)) + _, err = s.msgServer.Withdraw(s.ctx, msg) + s.Require().Error(err) +} + +func (s *TestSuite) TestWithdraw_Success() { + s.bankKeeper.EXPECT().SendCoinsFromModuleToAccount(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil).AnyTimes() + s.accountKeeper.EXPECT().HasAccount(gomock.Any(), gomock.Any()). + Return(true).AnyTimes() + + creator := sample.RandAccAddress() + createAccountMsg := types.NewMsgCreatePaymentAccount(creator.String()) + _, err := s.msgServer.CreatePaymentAccount(s.ctx, createAccountMsg) + s.Require().NoError(err) + paymentAddr := s.paymentKeeper.DerivePaymentAccountAddress(creator, 0) + paymentAccountRecord, _ := s.paymentKeeper.GetPaymentAccount(s.ctx, paymentAddr) + s.Require().True(paymentAccountRecord.Owner == creator.String()) + + record := types.NewStreamRecord(paymentAddr, s.ctx.BlockTime().Unix()) + record.StaticBalance = sdkmath.NewInt(200) + s.paymentKeeper.SetStreamRecord(s.ctx, record) + + msg := types.NewMsgWithdraw(creator.String(), paymentAddr.String(), sdkmath.NewInt(100)) + _, err = s.msgServer.Withdraw(s.ctx, msg) + s.Require().NoError(err) +} diff --git a/x/payment/keeper/payment_account_count.go b/x/payment/keeper/payment_account_count.go index e57f00c92..b00b009c0 100644 --- a/x/payment/keeper/payment_account_count.go +++ b/x/payment/keeper/payment_account_count.go @@ -37,17 +37,6 @@ func (k Keeper) GetPaymentAccountCount( return val, true } -// RemovePaymentAccountCount removes a paymentAccountCount from the store -func (k Keeper) RemovePaymentAccountCount( - ctx sdk.Context, - owner sdk.AccAddress, -) { - store := prefix.NewStore(ctx.KVStore(k.storeKey), types.PaymentAccountCountKeyPrefix) - store.Delete(types.PaymentAccountCountKey( - owner, - )) -} - // GetAllPaymentAccountCount returns all paymentAccountCount func (k Keeper) GetAllPaymentAccountCount(ctx sdk.Context) (list []types.PaymentAccountCount) { store := prefix.NewStore(ctx.KVStore(k.storeKey), types.PaymentAccountCountKeyPrefix) diff --git a/x/payment/keeper/payment_account_count_test.go b/x/payment/keeper/payment_account_count_test.go new file mode 100644 index 000000000..8e11e5eaa --- /dev/null +++ b/x/payment/keeper/payment_account_count_test.go @@ -0,0 +1,46 @@ +package keeper_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/bnb-chain/greenfield/testutil/sample" + "github.com/bnb-chain/greenfield/x/payment/types" +) + +func TestPaymentAccountCount(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + + owner1 := sample.RandAccAddress() + paymentCount1 := &types.PaymentAccountCount{ + Owner: owner1.String(), + Count: 1, + } + + owner2 := sample.RandAccAddress() + paymentCount2 := &types.PaymentAccountCount{ + Owner: owner2.String(), + Count: 3, + } + + // set + keeper.SetPaymentAccountCount(ctx, paymentCount1) + keeper.SetPaymentAccountCount(ctx, paymentCount2) + + // get + resp1, _ := keeper.GetPaymentAccountCount(ctx, owner1) + require.True(t, resp1.Owner == owner1.String()) + require.True(t, resp1.Count == paymentCount1.Count) + + resp2, _ := keeper.GetPaymentAccountCount(ctx, owner2) + require.True(t, resp2.Owner == owner2.String()) + require.True(t, resp2.Count == paymentCount2.Count) + + _, found := keeper.GetPaymentAccountCount(ctx, sample.RandAccAddress()) + require.True(t, !found) + + // get all + resp3 := keeper.GetAllPaymentAccountCount(ctx) + require.True(t, len(resp3) == 2) +} diff --git a/x/payment/keeper/payment_account_test.go b/x/payment/keeper/payment_account_test.go new file mode 100644 index 000000000..801881b4e --- /dev/null +++ b/x/payment/keeper/payment_account_test.go @@ -0,0 +1,52 @@ +package keeper_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + + "github.com/bnb-chain/greenfield/testutil/sample" + "github.com/bnb-chain/greenfield/x/payment/types" +) + +func TestPaymentAccount(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + + owner1 := sample.RandAccAddress() + addr1 := sample.RandAccAddress() + paymentAccount1 := &types.PaymentAccount{ + Owner: owner1.String(), + Addr: addr1.String(), + Refundable: true, + } + + owner2 := sample.RandAccAddress() + addr2 := sample.RandAccAddress() + paymentAccount2 := &types.PaymentAccount{ + Owner: owner2.String(), + Addr: addr2.String(), + Refundable: false, + } + + // set + keeper.SetPaymentAccount(ctx, paymentAccount1) + keeper.SetPaymentAccount(ctx, paymentAccount2) + + // get + resp1, _ := keeper.GetPaymentAccount(ctx, addr1) + require.True(t, resp1.Owner == owner1.String()) + require.True(t, resp1.Addr == addr1.String()) + require.True(t, resp1.Refundable == paymentAccount1.Refundable) + + resp2, _ := keeper.GetPaymentAccount(ctx, addr2) + require.True(t, resp2.Owner == owner2.String()) + require.True(t, resp2.Addr == addr2.String()) + require.True(t, resp2.Refundable == paymentAccount2.Refundable) + + _, found := keeper.GetPaymentAccount(ctx, sample.RandAccAddress()) + require.True(t, !found) + + // get all + resp3 := keeper.GetAllPaymentAccount(ctx) + require.True(t, len(resp3) == 2) +} diff --git a/x/payment/keeper/price_test.go b/x/payment/keeper/price_test.go new file mode 100644 index 000000000..15f2ba89d --- /dev/null +++ b/x/payment/keeper/price_test.go @@ -0,0 +1,39 @@ +package keeper_test + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + + "github.com/bnb-chain/greenfield/x/payment/types" + sp "github.com/bnb-chain/greenfield/x/sp/types" +) + +func TestGetStoragePrice(t *testing.T) { + keeper, ctx, depKeepers := makePaymentKeeper(t) + + primaryPrice := sp.SpStoragePrice{ + ReadPrice: sdk.NewDecWithPrec(2, 2), + FreeReadQuota: 0, + StorePrice: sdk.NewDecWithPrec(5, 1), + } + depKeepers.SpKeeper.EXPECT().GetSpStoragePriceByTime(gomock.Any(), gomock.Any(), gomock.Any()). + Return(primaryPrice, nil).AnyTimes() + + secondaryPrice := sp.SecondarySpStorePrice{ + StorePrice: sdk.NewDecWithPrec(2, 1), + } + depKeepers.SpKeeper.EXPECT().GetSecondarySpStorePriceByTime(gomock.Any(), gomock.Any()). + Return(secondaryPrice, nil).AnyTimes() + + resp, err := keeper.GetStoragePrice(ctx, types.StoragePriceParams{ + PrimarySp: 1, + PriceTime: 1, + }) + require.NoError(t, err) + require.True(t, resp.ReadPrice.Equal(primaryPrice.ReadPrice)) + require.True(t, resp.PrimaryStorePrice.Equal(primaryPrice.StorePrice)) + require.True(t, resp.SecondaryStorePrice.Equal(secondaryPrice.StorePrice)) +} diff --git a/x/payment/keeper/storage_fee_charge.go b/x/payment/keeper/storage_fee_charge.go index c5d3239df..6518e82d9 100644 --- a/x/payment/keeper/storage_fee_charge.go +++ b/x/payment/keeper/storage_fee_charge.go @@ -139,7 +139,7 @@ func (k Keeper) applyFrozenUserFlows(ctx sdk.Context, userFlows types.UserFlows, } streamRecordChange := types.NewDefaultStreamRecordChangeWithAddr(from). WithRateChange(totalActiveRate.Neg()).WithFrozenRateChange(totalFrozenRate.Neg()) - err := k.UpdateFrozenStreamRecord(ctx, streamRecord, streamRecordChange) + err := k.UpdateStreamRecord(ctx, streamRecord, streamRecordChange) if err != nil { return fmt.Errorf("apply stream record changes for user failed: %w", err) } diff --git a/x/payment/keeper/storage_fee_charge_test.go b/x/payment/keeper/storage_fee_charge_test.go index 7faf16db7..018a6cc0d 100644 --- a/x/payment/keeper/storage_fee_charge_test.go +++ b/x/payment/keeper/storage_fee_charge_test.go @@ -1,6 +1,7 @@ package keeper_test import ( + "errors" "sort" "testing" "time" @@ -49,42 +50,6 @@ func TestApplyFlowChanges(t *testing.T) { t.Logf("sp stream record: %+v", spStreamRecord) } -func TestSettleStreamRecord(t *testing.T) { - keeper, ctx, _ := makePaymentKeeper(t) - ctx = ctx.WithBlockTime(time.Unix(100, 0)) - user := sample.RandAccAddress() - rate := sdkmath.NewInt(-100) - staticBalance := sdkmath.NewInt(1e10) - change := types.NewDefaultStreamRecordChangeWithAddr(user).WithRateChange(rate).WithStaticBalanceChange(staticBalance) - sr := &types.StreamRecord{Account: user.String(), - OutFlowCount: 1, - StaticBalance: sdkmath.ZeroInt(), - BufferBalance: sdkmath.ZeroInt(), - LockBalance: sdkmath.ZeroInt(), - NetflowRate: sdkmath.ZeroInt(), - FrozenNetflowRate: sdkmath.ZeroInt(), - } - keeper.SetStreamRecord(ctx, sr) - _, err := keeper.UpdateStreamRecordByAddr(ctx, change) - require.NoError(t, err) - // check - streamRecord, found := keeper.GetStreamRecord(ctx, user) - require.True(t, found) - t.Logf("stream record: %+v", streamRecord) - // 345 seconds pass - var seconds int64 = 345 - ctx = ctx.WithBlockTime(ctx.BlockTime().Add(time.Duration(seconds) * time.Second)) - change = types.NewDefaultStreamRecordChangeWithAddr(user) - _, err = keeper.UpdateStreamRecordByAddr(ctx, change) - require.NoError(t, err) - userStreamRecord2, _ := keeper.GetStreamRecord(ctx, user) - t.Logf("stream record after %d seconds: %+v", seconds, userStreamRecord2) - require.Equal(t, userStreamRecord2.StaticBalance, streamRecord.StaticBalance.Add(rate.Mul(sdkmath.NewInt(seconds)))) - require.Equal(t, userStreamRecord2.BufferBalance, streamRecord.BufferBalance) - require.Equal(t, userStreamRecord2.NetflowRate, streamRecord.NetflowRate) - require.Equal(t, userStreamRecord2.CrudTimestamp, streamRecord.CrudTimestamp+seconds) -} - func TestMergeStreamRecordChanges(t *testing.T) { users := []sdk.AccAddress{ sample.RandAccAddress(), @@ -116,96 +81,180 @@ func TestMergeStreamRecordChanges(t *testing.T) { }) } -func TestAutoForceSettle(t *testing.T) { - keeper, ctx, depKeepers := makePaymentKeeper(t) - t.Logf("depKeepers: %+v", depKeepers) - params := keeper.GetParams(ctx) - var startTime int64 = 100 - ctx = ctx.WithBlockTime(time.Unix(startTime, 0)) - user := sample.RandAccAddress() - rate := sdkmath.NewInt(100) - sp := sample.RandAccAddress() - userInitBalance := sdkmath.NewInt(int64(100*params.VersionedParams.ReserveTime) + 1) // just enough for reserve - // init balance - streamRecordChanges := []types.StreamRecordChange{ - *types.NewDefaultStreamRecordChangeWithAddr(user).WithStaticBalanceChange(userInitBalance), +func TestApplyUserFlows_ActiveStreamRecord(t *testing.T) { + keeper, ctx, deepKeepers := makePaymentKeeper(t) + ctx = ctx.WithIsCheckTx(true) + + from := sample.RandAccAddress() + userFlows := types.UserFlows{ + From: from, } - err := keeper.ApplyStreamRecordChanges(ctx, streamRecordChanges) - require.NoError(t, err) - userStreamRecord, found := keeper.GetStreamRecord(ctx, user) - t.Logf("user stream record: %+v", userStreamRecord) - require.True(t, found) - flowChanges := []types.OutFlow{ - {ToAddress: sp.String(), Rate: rate}, + + toAddr1 := sample.RandAccAddress() + outFlow1 := types.OutFlow{ + ToAddress: toAddr1.String(), + Rate: sdkmath.NewInt(100), + } + userFlows.Flows = append(userFlows.Flows, outFlow1) + + toAddr2 := sample.RandAccAddress() + outFlow2 := types.OutFlow{ + ToAddress: toAddr2.String(), + Rate: sdkmath.NewInt(200), } - userFlows := types.UserFlows{Flows: flowChanges, From: user} + userFlows.Flows = append(userFlows.Flows, outFlow2) + + // no bank account + deepKeepers.AccountKeeper.EXPECT().HasAccount(gomock.Any(), gomock.Any()). + Return(false).Times(1) + err := keeper.ApplyUserFlowsList(ctx, []types.UserFlows{userFlows}) + require.ErrorContains(t, err, "balance not enough") + + // has bank account, but balance is not enough + deepKeepers.AccountKeeper.EXPECT().HasAccount(gomock.Any(), gomock.Any()). + Return(true).AnyTimes() + deepKeepers.BankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(errors.New("transfer error")).Times(1) + err = keeper.ApplyUserFlowsList(ctx, []types.UserFlows{userFlows}) + require.ErrorContains(t, err, "balance not enough") + + // has bank account, and balance is enough + deepKeepers.BankKeeper.EXPECT().SendCoinsFromAccountToModule(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil).AnyTimes() err = keeper.ApplyUserFlowsList(ctx, []types.UserFlows{userFlows}) require.NoError(t, err) - userStreamRecord, found = keeper.GetStreamRecord(ctx, user) - t.Logf("user stream record: %+v", userStreamRecord) - require.True(t, found) - outFlows := keeper.GetOutFlows(ctx, user) - require.Equal(t, 1, len(outFlows)) - require.Equal(t, outFlows[0].ToAddress, sp.String()) - spStreamRecord, found := keeper.GetStreamRecord(ctx, sp) - t.Logf("sp stream record: %+v", spStreamRecord) - require.True(t, found) - require.Equal(t, spStreamRecord.NetflowRate, rate) - require.Equal(t, spStreamRecord.StaticBalance, sdkmath.ZeroInt()) - require.Equal(t, spStreamRecord.BufferBalance, sdkmath.ZeroInt()) - // check auto settle queue - autoSettleQueue := keeper.GetAllAutoSettleRecord(ctx) - t.Logf("auto settle queue: %+v", autoSettleQueue) - require.Equal(t, len(autoSettleQueue), 1) - require.Equal(t, autoSettleQueue[0].Addr, user.String()) - require.Equal(t, autoSettleQueue[0].Timestamp, startTime+int64(params.VersionedParams.ReserveTime)-int64(params.ForcedSettleTime)) - // 1 day pass - ctx = ctx.WithBlockTime(ctx.BlockTime().Add(time.Duration(86400) * time.Second)) - // update and deposit to user for extra 100s - depKeepers.AccountKeeper.EXPECT().HasAccount(gomock.Any(), gomock.Any()).Return(false).AnyTimes() - userAddBalance := rate.MulRaw(100) - change := types.NewDefaultStreamRecordChangeWithAddr(user).WithStaticBalanceChange(userAddBalance) - ret, err := keeper.UpdateStreamRecordByAddr(ctx, change) - require.NoError(t, err) - userStreamRecord = ret - t.Logf("user stream record: %+v", userStreamRecord) - require.True(t, found) - require.True(t, userStreamRecord.StaticBalance.IsNegative()) - change = types.NewDefaultStreamRecordChangeWithAddr(sp) - _, err = keeper.UpdateStreamRecordByAddr(ctx, change) - require.NoError(t, err) - spStreamRecord, _ = keeper.GetStreamRecord(ctx, sp) - t.Logf("sp stream record: %+v", spStreamRecord) - autoSettleQueue2 := keeper.GetAllAutoSettleRecord(ctx) - t.Logf("auto settle queue: %+v", autoSettleQueue2) - require.Equal(t, autoSettleQueue[0].Timestamp+100, autoSettleQueue2[0].Timestamp) - // reserve time - forced settle time - 1 day + 101s pass - ctx = ctx.WithBlockTime(ctx.BlockTime().Add(time.Duration(params.VersionedParams.ReserveTime-params.ForcedSettleTime-86400+101) * time.Second)) - usrBeforeForceSettle, _ := keeper.GetStreamRecord(ctx, user) - t.Logf("usrBeforeForceSettle: %s", usrBeforeForceSettle) - ctx = ctx.WithValue(types.ForceUpdateStreamRecordKey, true) - time.Sleep(1 * time.Second) - keeper.AutoSettle(ctx) + fromRecord, _ := keeper.GetStreamRecord(ctx, from) + require.True(t, fromRecord.Status == types.STREAM_ACCOUNT_STATUS_ACTIVE) + require.True(t, fromRecord.NetflowRate.Int64() == -300) + require.True(t, fromRecord.StaticBalance.Int64() == 0) + require.True(t, fromRecord.FrozenNetflowRate.Int64() == 0) + require.True(t, fromRecord.LockBalance.Int64() == 0) + require.True(t, fromRecord.BufferBalance.Int64() > 0) - usrAfterForceSettle, found := keeper.GetStreamRecord(ctx, user) - require.True(t, found) - t.Logf("usrAfterForceSettle: %s", usrAfterForceSettle) - // user has been force settled - require.Equal(t, usrAfterForceSettle.StaticBalance, sdkmath.ZeroInt()) - require.Equal(t, usrAfterForceSettle.BufferBalance, sdkmath.ZeroInt()) - require.Equal(t, usrAfterForceSettle.NetflowRate, sdkmath.ZeroInt()) - require.Equal(t, usrAfterForceSettle.Status, types.STREAM_ACCOUNT_STATUS_FROZEN) - change = types.NewDefaultStreamRecordChangeWithAddr(sp) - _, err = keeper.UpdateStreamRecordByAddr(ctx, change) + to1Record, _ := keeper.GetStreamRecord(ctx, toAddr1) + require.True(t, to1Record.Status == types.STREAM_ACCOUNT_STATUS_ACTIVE) + require.True(t, to1Record.NetflowRate.Int64() == 100) + require.True(t, to1Record.StaticBalance.Int64() == 0) + require.True(t, to1Record.FrozenNetflowRate.Int64() == 0) + require.True(t, to1Record.LockBalance.Int64() == 0) + require.True(t, to1Record.BufferBalance.Int64() == 0) + + to2Record, _ := keeper.GetStreamRecord(ctx, toAddr2) + require.True(t, to2Record.Status == types.STREAM_ACCOUNT_STATUS_ACTIVE) + require.True(t, to2Record.NetflowRate.Int64() == 200) + require.True(t, to2Record.StaticBalance.Int64() == 0) + require.True(t, to2Record.FrozenNetflowRate.Int64() == 0) + require.True(t, to2Record.LockBalance.Int64() == 0) + require.True(t, to2Record.BufferBalance.Int64() == 0) +} + +func TestApplyUserFlows_Frozen(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + + from := sample.RandAccAddress() + toAddr1 := sample.RandAccAddress() + toAddr2 := sample.RandAccAddress() + + // the account is frozen, and during auto settle or auto resume + fromStreamRecord := types.NewStreamRecord(from, ctx.BlockTime().Unix()) + fromStreamRecord.Status = types.STREAM_ACCOUNT_STATUS_FROZEN + fromStreamRecord.NetflowRate = sdkmath.NewInt(-100) + fromStreamRecord.FrozenNetflowRate = sdkmath.NewInt(-200) + fromStreamRecord.StaticBalance = sdkmath.ZeroInt() + fromStreamRecord.OutFlowCount = 4 + keeper.SetStreamRecord(ctx, fromStreamRecord) + + keeper.SetOutFlow(ctx, from, &types.OutFlow{ + ToAddress: toAddr1.String(), + Rate: sdkmath.NewInt(40), + Status: types.OUT_FLOW_STATUS_ACTIVE, + }) + keeper.SetOutFlow(ctx, from, &types.OutFlow{ + ToAddress: sample.RandAccAddress().String(), + Rate: sdkmath.NewInt(60), + Status: types.OUT_FLOW_STATUS_ACTIVE, + }) + keeper.SetOutFlow(ctx, from, &types.OutFlow{ + ToAddress: toAddr2.String(), + Rate: sdkmath.NewInt(120), + Status: types.OUT_FLOW_STATUS_FROZEN, + }) + keeper.SetOutFlow(ctx, from, &types.OutFlow{ + ToAddress: sample.RandAccAddress().String(), + Rate: sdkmath.NewInt(80), + Status: types.OUT_FLOW_STATUS_FROZEN, + }) + + to1StreamRecord := types.NewStreamRecord(toAddr1, ctx.BlockTime().Unix()) + to1StreamRecord.NetflowRate = sdkmath.NewInt(300) + to1StreamRecord.StaticBalance = sdkmath.NewInt(300) + keeper.SetStreamRecord(ctx, to1StreamRecord) + + to2StreamRecord := types.NewStreamRecord(toAddr2, ctx.BlockTime().Unix()) + to2StreamRecord.NetflowRate = sdkmath.NewInt(400) + to2StreamRecord.StaticBalance = sdkmath.NewInt(400) + keeper.SetStreamRecord(ctx, to2StreamRecord) + + userFlows := types.UserFlows{ + From: from, + } + + outFlow1 := types.OutFlow{ + ToAddress: toAddr1.String(), + Rate: sdkmath.NewInt(-40), + } + userFlows.Flows = append(userFlows.Flows, outFlow1) + + outFlow2 := types.OutFlow{ + ToAddress: toAddr2.String(), + Rate: sdkmath.NewInt(-60), + } + userFlows.Flows = append(userFlows.Flows, outFlow2) + + // update frozen stream record needs force flag + err := keeper.ApplyUserFlowsList(ctx, []types.UserFlows{userFlows}) + require.ErrorContains(t, err, "frozen") + + ctx = ctx.WithValue(types.ForceUpdateStreamRecordKey, true) + err = keeper.ApplyUserFlowsList(ctx, []types.UserFlows{userFlows}) require.NoError(t, err) - spStreamRecord, _ = keeper.GetStreamRecord(ctx, sp) - t.Logf("sp stream record: %+v", spStreamRecord) - autoSettleQueue3 := keeper.GetAllAutoSettleRecord(ctx) - t.Logf("auto settle queue: %+v", autoSettleQueue3) - require.Equal(t, len(autoSettleQueue3), 0) - govStreamRecord, found := keeper.GetStreamRecord(ctx, types.GovernanceAddress) - require.True(t, found) - t.Logf("gov stream record: %+v", govStreamRecord) - require.Equal(t, govStreamRecord.StaticBalance.Add(spStreamRecord.StaticBalance), userInitBalance.Add(userAddBalance)) + + fromRecord, _ := keeper.GetStreamRecord(ctx, from) + require.True(t, fromRecord.Status == types.STREAM_ACCOUNT_STATUS_FROZEN) + require.True(t, fromRecord.StaticBalance.Int64() == 0) + require.True(t, fromRecord.NetflowRate.Int64() == -60) + require.True(t, fromRecord.FrozenNetflowRate.Int64() == -140) + require.True(t, fromRecord.LockBalance.Int64() == 0) + require.True(t, fromRecord.BufferBalance.Int64() == 0) + + outFlows := keeper.GetOutFlows(ctx, from) + require.True(t, len(outFlows) == 3) + // the out flow to toAddr1 should be deleted + // the out flow to toAddr2 should be still there + to1Found := false + for _, outFlow := range outFlows { + if outFlow.ToAddress == toAddr1.String() { + to1Found = true + } + if outFlow.ToAddress == toAddr2.String() { + require.True(t, outFlow.Rate.Int64() == 60) + require.True(t, outFlow.Status == types.OUT_FLOW_STATUS_FROZEN) + } + } + require.True(t, !to1Found) + + to1Record, _ := keeper.GetStreamRecord(ctx, toAddr1) + require.True(t, to1Record.Status == types.STREAM_ACCOUNT_STATUS_ACTIVE) + require.True(t, to1Record.NetflowRate.Int64() == 260) + require.True(t, to1Record.FrozenNetflowRate.Int64() == 0) + require.True(t, to1Record.LockBalance.Int64() == 0) + require.True(t, to1Record.BufferBalance.Int64() == 0) + + to2Record, _ := keeper.GetStreamRecord(ctx, toAddr2) + require.True(t, to2Record.Status == types.STREAM_ACCOUNT_STATUS_ACTIVE) + require.True(t, to2Record.NetflowRate.Int64() == 400) // the outflow is frozen, which means the flow had been deduced + require.True(t, to2Record.FrozenNetflowRate.Int64() == 0) + require.True(t, to2Record.LockBalance.Int64() == 0) + require.True(t, to2Record.BufferBalance.Int64() == 0) } diff --git a/x/payment/keeper/stream_record.go b/x/payment/keeper/stream_record.go index a965f44e9..798016b06 100644 --- a/x/payment/keeper/stream_record.go +++ b/x/payment/keeper/stream_record.go @@ -114,6 +114,10 @@ func (k Keeper) GetAllStreamRecord(ctx sdk.Context) (list []types.StreamRecord) // it only handles the lock balance change and ignore the other changes(since the streams are already changed and the // accumulated OutFlows are changed outside this function) func (k Keeper) UpdateFrozenStreamRecord(ctx sdk.Context, streamRecord *types.StreamRecord, change *types.StreamRecordChange) error { + forced, _ := ctx.Value(types.ForceUpdateStreamRecordKey).(bool) + if !forced { + return fmt.Errorf("stream record %s is frozen", streamRecord.Account) + } // update lock balance if !change.LockBalanceChange.IsZero() { streamRecord.LockBalance = streamRecord.LockBalance.Add(change.LockBalanceChange) @@ -132,6 +136,10 @@ func (k Keeper) UpdateFrozenStreamRecord(ctx sdk.Context, streamRecord *types.St } func (k Keeper) UpdateStreamRecord(ctx sdk.Context, streamRecord *types.StreamRecord, change *types.StreamRecordChange) error { + if streamRecord.Status == types.STREAM_ACCOUNT_STATUS_FROZEN { + return k.UpdateFrozenStreamRecord(ctx, streamRecord, change) + } + forced, _ := ctx.Value(types.ForceUpdateStreamRecordKey).(bool) // force update in end block isPay := change.StaticBalanceChange.IsNegative() || change.RateChange.IsNegative() currentTimestamp := ctx.BlockTime().Unix() @@ -217,13 +225,11 @@ func (k Keeper) UpdateStreamRecordByAddr(ctx sdk.Context, change *types.StreamRe func (k Keeper) ForceSettle(ctx sdk.Context, streamRecord *types.StreamRecord) error { totalBalance := streamRecord.StaticBalance.Add(streamRecord.BufferBalance) - if totalBalance.IsPositive() { - change := types.NewDefaultStreamRecordChangeWithAddr(types.GovernanceAddress).WithStaticBalanceChange(totalBalance) - _, err := k.UpdateStreamRecordByAddr(ctx, change) - if err != nil { - telemetry.IncrCounter(1, types.GovernanceAddressLackBalanceLabel) - return fmt.Errorf("update governance stream record failed: %w", err) - } + change := types.NewDefaultStreamRecordChangeWithAddr(types.GovernanceAddress).WithStaticBalanceChange(totalBalance) + _, err := k.UpdateStreamRecordByAddr(ctx, change) + if err != nil { + telemetry.IncrCounter(1, types.GovernanceAddressLackBalanceLabel) + return fmt.Errorf("update governance stream record failed: %w", err) } // force settle streamRecord.StaticBalance = sdkmath.ZeroInt() diff --git a/x/payment/keeper/stream_record_test.go b/x/payment/keeper/stream_record_test.go index a726888d4..2f5c13855 100644 --- a/x/payment/keeper/stream_record_test.go +++ b/x/payment/keeper/stream_record_test.go @@ -752,3 +752,169 @@ func TestAutoSettle_SettleInMultipleBlocks_AutoResumeExists(t *testing.T) { require.Equal(t, gvg3StreamRecord.NetflowRate, sdk.NewInt(150)) require.Equal(t, gvg3StreamRecord.FrozenNetflowRate, sdkmath.ZeroInt()) } + +func TestUpdateStreamRecord_FrozenAccountLockBalance(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + ctx = ctx.WithBlockTime(time.Now()) + + user := sample.RandAccAddress() + streamRecord := &types.StreamRecord{ + StaticBalance: sdkmath.ZeroInt(), + BufferBalance: sdkmath.ZeroInt(), + LockBalance: sdkmath.NewInt(1000), + Account: user.String(), + Status: types.STREAM_ACCOUNT_STATUS_FROZEN, + NetflowRate: sdkmath.NewInt(0), + FrozenNetflowRate: sdkmath.NewInt(100).Neg(), + OutFlowCount: 1, + } + keeper.SetStreamRecord(ctx, streamRecord) + + // update fail when no force flag + change := types.NewDefaultStreamRecordChangeWithAddr(user). + WithLockBalanceChange(streamRecord.LockBalance.Neg()) + _, err := keeper.UpdateStreamRecordByAddr(ctx, change) + require.ErrorContains(t, err, "is frozen") + + // update success when there is force flag + ctx = ctx.WithValue(types.ForceUpdateStreamRecordKey, true) + change = types.NewDefaultStreamRecordChangeWithAddr(user). + WithLockBalanceChange(streamRecord.LockBalance.Neg()) + _, err = keeper.UpdateStreamRecordByAddr(ctx, change) + require.NoError(t, err) + + streamRecord, _ = keeper.GetStreamRecord(ctx, user) + require.True(t, streamRecord.Status == types.STREAM_ACCOUNT_STATUS_FROZEN) + require.True(t, streamRecord.LockBalance.IsZero()) + require.True(t, streamRecord.StaticBalance.Int64() == 1000) +} + +func TestSettleStreamRecord(t *testing.T) { + keeper, ctx, _ := makePaymentKeeper(t) + ctx = ctx.WithBlockTime(time.Unix(100, 0)) + user := sample.RandAccAddress() + rate := sdkmath.NewInt(-100) + staticBalance := sdkmath.NewInt(1e10) + change := types.NewDefaultStreamRecordChangeWithAddr(user).WithRateChange(rate).WithStaticBalanceChange(staticBalance) + sr := &types.StreamRecord{Account: user.String(), + OutFlowCount: 1, + StaticBalance: sdkmath.ZeroInt(), + BufferBalance: sdkmath.ZeroInt(), + LockBalance: sdkmath.ZeroInt(), + NetflowRate: sdkmath.ZeroInt(), + FrozenNetflowRate: sdkmath.ZeroInt(), + } + keeper.SetStreamRecord(ctx, sr) + _, err := keeper.UpdateStreamRecordByAddr(ctx, change) + require.NoError(t, err) + // check + streamRecord, found := keeper.GetStreamRecord(ctx, user) + require.True(t, found) + t.Logf("stream record: %+v", streamRecord) + // 345 seconds pass + var seconds int64 = 345 + ctx = ctx.WithBlockTime(ctx.BlockTime().Add(time.Duration(seconds) * time.Second)) + change = types.NewDefaultStreamRecordChangeWithAddr(user) + _, err = keeper.UpdateStreamRecordByAddr(ctx, change) + require.NoError(t, err) + userStreamRecord2, _ := keeper.GetStreamRecord(ctx, user) + t.Logf("stream record after %d seconds: %+v", seconds, userStreamRecord2) + require.Equal(t, userStreamRecord2.StaticBalance, streamRecord.StaticBalance.Add(rate.Mul(sdkmath.NewInt(seconds)))) + require.Equal(t, userStreamRecord2.BufferBalance, streamRecord.BufferBalance) + require.Equal(t, userStreamRecord2.NetflowRate, streamRecord.NetflowRate) + require.Equal(t, userStreamRecord2.CrudTimestamp, streamRecord.CrudTimestamp+seconds) +} + +func TestAutoForceSettle(t *testing.T) { + keeper, ctx, depKeepers := makePaymentKeeper(t) + t.Logf("depKeepers: %+v", depKeepers) + params := keeper.GetParams(ctx) + var startTime int64 = 100 + ctx = ctx.WithBlockTime(time.Unix(startTime, 0)) + user := sample.RandAccAddress() + rate := sdkmath.NewInt(100) + sp := sample.RandAccAddress() + userInitBalance := sdkmath.NewInt(int64(100*params.VersionedParams.ReserveTime) + 1) // just enough for reserve + // init balance + streamRecordChanges := []types.StreamRecordChange{ + *types.NewDefaultStreamRecordChangeWithAddr(user).WithStaticBalanceChange(userInitBalance), + } + err := keeper.ApplyStreamRecordChanges(ctx, streamRecordChanges) + require.NoError(t, err) + userStreamRecord, found := keeper.GetStreamRecord(ctx, user) + t.Logf("user stream record: %+v", userStreamRecord) + require.True(t, found) + flowChanges := []types.OutFlow{ + {ToAddress: sp.String(), Rate: rate}, + } + userFlows := types.UserFlows{Flows: flowChanges, From: user} + err = keeper.ApplyUserFlowsList(ctx, []types.UserFlows{userFlows}) + require.NoError(t, err) + userStreamRecord, found = keeper.GetStreamRecord(ctx, user) + t.Logf("user stream record: %+v", userStreamRecord) + require.True(t, found) + outFlows := keeper.GetOutFlows(ctx, user) + require.Equal(t, 1, len(outFlows)) + require.Equal(t, outFlows[0].ToAddress, sp.String()) + spStreamRecord, found := keeper.GetStreamRecord(ctx, sp) + t.Logf("sp stream record: %+v", spStreamRecord) + require.True(t, found) + require.Equal(t, spStreamRecord.NetflowRate, rate) + require.Equal(t, spStreamRecord.StaticBalance, sdkmath.ZeroInt()) + require.Equal(t, spStreamRecord.BufferBalance, sdkmath.ZeroInt()) + // check auto settle queue + autoSettleQueue := keeper.GetAllAutoSettleRecord(ctx) + t.Logf("auto settle queue: %+v", autoSettleQueue) + require.Equal(t, len(autoSettleQueue), 1) + require.Equal(t, autoSettleQueue[0].Addr, user.String()) + require.Equal(t, autoSettleQueue[0].Timestamp, startTime+int64(params.VersionedParams.ReserveTime)-int64(params.ForcedSettleTime)) + // 1 day pass + ctx = ctx.WithBlockTime(ctx.BlockTime().Add(time.Duration(86400) * time.Second)) + // update and deposit to user for extra 100s + depKeepers.AccountKeeper.EXPECT().HasAccount(gomock.Any(), gomock.Any()).Return(false).AnyTimes() + userAddBalance := rate.MulRaw(100) + change := types.NewDefaultStreamRecordChangeWithAddr(user).WithStaticBalanceChange(userAddBalance) + ret, err := keeper.UpdateStreamRecordByAddr(ctx, change) + require.NoError(t, err) + userStreamRecord = ret + t.Logf("user stream record: %+v", userStreamRecord) + require.True(t, found) + require.True(t, userStreamRecord.StaticBalance.IsNegative()) + change = types.NewDefaultStreamRecordChangeWithAddr(sp) + _, err = keeper.UpdateStreamRecordByAddr(ctx, change) + require.NoError(t, err) + spStreamRecord, _ = keeper.GetStreamRecord(ctx, sp) + t.Logf("sp stream record: %+v", spStreamRecord) + autoSettleQueue2 := keeper.GetAllAutoSettleRecord(ctx) + t.Logf("auto settle queue: %+v", autoSettleQueue2) + require.Equal(t, autoSettleQueue[0].Timestamp+100, autoSettleQueue2[0].Timestamp) + // reserve time - forced settle time - 1 day + 101s pass + ctx = ctx.WithBlockTime(ctx.BlockTime().Add(time.Duration(params.VersionedParams.ReserveTime-params.ForcedSettleTime-86400+101) * time.Second)) + usrBeforeForceSettle, _ := keeper.GetStreamRecord(ctx, user) + t.Logf("usrBeforeForceSettle: %s", usrBeforeForceSettle) + + ctx = ctx.WithValue(types.ForceUpdateStreamRecordKey, true) + time.Sleep(1 * time.Second) + keeper.AutoSettle(ctx) + + usrAfterForceSettle, found := keeper.GetStreamRecord(ctx, user) + require.True(t, found) + t.Logf("usrAfterForceSettle: %s", usrAfterForceSettle) + // user has been force settled + require.Equal(t, usrAfterForceSettle.StaticBalance, sdkmath.ZeroInt()) + require.Equal(t, usrAfterForceSettle.BufferBalance, sdkmath.ZeroInt()) + require.Equal(t, usrAfterForceSettle.NetflowRate, sdkmath.ZeroInt()) + require.Equal(t, usrAfterForceSettle.Status, types.STREAM_ACCOUNT_STATUS_FROZEN) + change = types.NewDefaultStreamRecordChangeWithAddr(sp) + _, err = keeper.UpdateStreamRecordByAddr(ctx, change) + require.NoError(t, err) + spStreamRecord, _ = keeper.GetStreamRecord(ctx, sp) + t.Logf("sp stream record: %+v", spStreamRecord) + autoSettleQueue3 := keeper.GetAllAutoSettleRecord(ctx) + t.Logf("auto settle queue: %+v", autoSettleQueue3) + require.Equal(t, len(autoSettleQueue3), 0) + govStreamRecord, found := keeper.GetStreamRecord(ctx, types.GovernanceAddress) + require.True(t, found) + t.Logf("gov stream record: %+v", govStreamRecord) + require.Equal(t, govStreamRecord.StaticBalance.Add(spStreamRecord.StaticBalance), userInitBalance.Add(userAddBalance)) +} diff --git a/x/payment/module.go b/x/payment/module.go index 3dc2e77e5..0c3abf803 100644 --- a/x/payment/module.go +++ b/x/payment/module.go @@ -81,7 +81,7 @@ func (a AppModuleBasic) GetTxCmd() *cobra.Command { // GetQueryCmd returns the root query command for the module. The subcommands of this root command are used by end-users to generate new queries to the subset of the state defined by the module func (AppModuleBasic) GetQueryCmd() *cobra.Command { - return cli.GetQueryCmd(types.StoreKey) + return cli.GetQueryCmd() } // ---------------------------------------------------------------------------- diff --git a/x/payment/module_simulation.go b/x/payment/module_simulation.go index 62b240343..99746bd8a 100644 --- a/x/payment/module_simulation.go +++ b/x/payment/module_simulation.go @@ -15,7 +15,7 @@ import ( // avoid unused import issue var ( - _ = sample.AccAddress + _ = sample.RandAccAddressHex _ = paymentsimulation.FindAccount _ = simulation.MsgEntryKind _ = baseapp.Paramspace diff --git a/x/permission/module_simulation.go b/x/permission/module_simulation.go index b62b7b00f..5a1bd3239 100644 --- a/x/permission/module_simulation.go +++ b/x/permission/module_simulation.go @@ -14,7 +14,7 @@ import ( // avoid unused import issue var ( - _ = sample.AccAddress + _ = sample.RandAccAddressHex _ = permissionsimulation.FindAccount _ = simulation.MsgEntryKind _ = baseapp.Paramspace diff --git a/x/sp/keeper/msg_server_test.go b/x/sp/keeper/msg_server_test.go index e0085691c..a069bd3cc 100644 --- a/x/sp/keeper/msg_server_test.go +++ b/x/sp/keeper/msg_server_test.go @@ -52,7 +52,7 @@ func (s *KeeperTestSuite) TestMsgCreateStorageProvider() { Identity: "", }, SpAddress: operatorAddr.String(), - FundingAddress: sample.AccAddress(), + FundingAddress: sample.RandAccAddressHex(), SealAddress: sealAddr.String(), ApprovalAddress: approvalAddr.String(), GcAddress: gcAddr.String(), diff --git a/x/sp/keeper/sp_test.go b/x/sp/keeper/sp_test.go index a92f2d150..64b2ec52f 100644 --- a/x/sp/keeper/sp_test.go +++ b/x/sp/keeper/sp_test.go @@ -17,7 +17,7 @@ func (s *KeeperTestSuite) TestSetGetStorageProvider() { keeper := s.spKeeper ctx := s.ctx sp := &types.StorageProvider{Id: 100} - spAccStr := sample.AccAddress() + spAccStr := sample.RandAccAddressHex() spAcc := sdk.MustAccAddressFromHex(spAccStr) sp.OperatorAddress = spAcc.String() @@ -34,16 +34,16 @@ func (s *KeeperTestSuite) TestSetGetStorageProvider() { func (s *KeeperTestSuite) TestStorageProviderBasics() { k := s.spKeeper ctx := s.ctx - spAccStr := sample.AccAddress() + spAccStr := sample.RandAccAddressHex() spAcc := sdk.MustAccAddressFromHex(spAccStr) - fundingAccStr := sample.AccAddress() + fundingAccStr := sample.RandAccAddressHex() fundingAcc := sdk.MustAccAddressFromHex(fundingAccStr) - sealAccStr := sample.AccAddress() + sealAccStr := sample.RandAccAddressHex() sealAcc := sdk.MustAccAddressFromHex(sealAccStr) - approvalAccStr := sample.AccAddress() + approvalAccStr := sample.RandAccAddressHex() approvalAcc := sdk.MustAccAddressFromHex(approvalAccStr) blsPubKey := sample.RandBlsPubKey() @@ -98,16 +98,16 @@ func (s *KeeperTestSuite) TestSlashBasic() { k := s.spKeeper ctx := s.ctx - spAccStr := sample.AccAddress() + spAccStr := sample.RandAccAddressHex() spAcc := sdk.MustAccAddressFromHex(spAccStr) - fundingAccStr := sample.AccAddress() + fundingAccStr := sample.RandAccAddressHex() fundingAcc := sdk.MustAccAddressFromHex(fundingAccStr) - sealAccStr := sample.AccAddress() + sealAccStr := sample.RandAccAddressHex() sealAcc := sdk.MustAccAddressFromHex(sealAccStr) - approvalAccStr := sample.AccAddress() + approvalAccStr := sample.RandAccAddressHex() approvalAcc := sdk.MustAccAddressFromHex(approvalAccStr) blsPubKey := sample.RandBlsPubKey() @@ -130,7 +130,7 @@ func (s *KeeperTestSuite) TestSlashBasic() { require.EqualValues(s.T(), found, true) rewardInfo := types.RewardInfo{ - Address: sample.AccAddress(), + Address: sample.RandAccAddressHex(), Amount: sdk.NewCoin(types2.Denom, math.NewIntWithDecimal(10, types2.DecimalBNB)), } diff --git a/x/sp/module_simulation.go b/x/sp/module_simulation.go index 369c1a97e..35d2a7207 100644 --- a/x/sp/module_simulation.go +++ b/x/sp/module_simulation.go @@ -14,7 +14,7 @@ import ( // avoid unused import issue var ( - _ = sample.AccAddress + _ = sample.RandAccAddressHex _ = spsimulation.FindAccount _ = simulation.MsgEntryKind _ = baseapp.Paramspace diff --git a/x/storage/keeper/cross_app_bucket.go b/x/storage/keeper/cross_app_bucket.go index a3d85a08d..6b6fe84de 100644 --- a/x/storage/keeper/cross_app_bucket.go +++ b/x/storage/keeper/cross_app_bucket.go @@ -13,10 +13,10 @@ import ( var _ sdk.CrossChainApplication = &BucketApp{} type BucketApp struct { - storageKeeper Keeper + storageKeeper types.StorageKeeper } -func NewBucketApp(keeper Keeper) *BucketApp { +func NewBucketApp(keeper types.StorageKeeper) *BucketApp { return &BucketApp{ storageKeeper: keeper, } @@ -216,7 +216,7 @@ func (app *BucketApp) handleCreateBucketSynPackage(ctx sdk.Context, appCtx *sdk. createBucketPackage.Creator, createBucketPackage.BucketName, createBucketPackage.PrimarySpAddress, - &CreateBucketOptions{ + &types.CreateBucketOptions{ Visibility: types.VisibilityType(createBucketPackage.Visibility), SourceType: types.SOURCE_TYPE_BSC_CROSS_CHAIN, ChargedReadQuota: createBucketPackage.ChargedReadQuota, @@ -267,6 +267,7 @@ func (app *BucketApp) handleDeleteBucketSynPackage(ctx sdk.Context, header *sdk. return sdk.ExecuteResult{ Payload: types.DeleteBucketAckPackage{ Status: types.StatusFail, + Id: deleteBucketPackage.Id, ExtraData: deleteBucketPackage.ExtraData, }.MustSerialize(), Err: err, @@ -281,6 +282,7 @@ func (app *BucketApp) handleDeleteBucketSynPackage(ctx sdk.Context, header *sdk. return sdk.ExecuteResult{ Payload: types.DeleteBucketAckPackage{ Status: types.StatusFail, + Id: deleteBucketPackage.Id, ExtraData: deleteBucketPackage.ExtraData, }.MustSerialize(), Err: types.ErrNoSuchBucket, @@ -290,7 +292,7 @@ func (app *BucketApp) handleDeleteBucketSynPackage(ctx sdk.Context, header *sdk. err = app.storageKeeper.DeleteBucket(ctx, deleteBucketPackage.Operator, bucketInfo.BucketName, - DeleteBucketOptions{ + types.DeleteBucketOptions{ SourceType: types.SOURCE_TYPE_BSC_CROSS_CHAIN, }, ) @@ -298,6 +300,7 @@ func (app *BucketApp) handleDeleteBucketSynPackage(ctx sdk.Context, header *sdk. return sdk.ExecuteResult{ Payload: types.DeleteBucketAckPackage{ Status: types.StatusFail, + Id: deleteBucketPackage.Id, ExtraData: deleteBucketPackage.ExtraData, }.MustSerialize(), Err: err, diff --git a/x/storage/keeper/cross_app_bucket_test.go b/x/storage/keeper/cross_app_bucket_test.go new file mode 100644 index 000000000..c4e9f35df --- /dev/null +++ b/x/storage/keeper/cross_app_bucket_test.go @@ -0,0 +1,57 @@ +package keeper_test + +import ( + "fmt" + "math/big" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/golang/mock/gomock" + + "github.com/bnb-chain/greenfield/testutil/sample" + "github.com/bnb-chain/greenfield/x/storage/keeper" + "github.com/bnb-chain/greenfield/x/storage/types" +) + +func (s *TestSuite) TestSynDeleteBucket() { + pack := types.DeleteBucketAckPackage{ + Status: 1, + Id: big.NewInt(10), + ExtraData: []byte("x"), + } + pack.MustSerialize() + ctrl := gomock.NewController(s.T()) + storageKeeper := types.NewMockStorageKeeper(ctrl) + storageKeeper.EXPECT().Logger(gomock.Any()).Return(s.ctx.Logger()).AnyTimes() + + app := keeper.NewBucketApp(storageKeeper) + deleteSynPackage := types.DeleteBucketSynPackage{ + Operator: sample.RandAccAddress(), + Id: big.NewInt(10), + ExtraData: []byte("extra data"), + } + + serializedSynPackage := deleteSynPackage.MustSerialize() + serializedSynPackage = append([]byte{types.OperationDeleteBucket}, serializedSynPackage...) + + // case 1: bucket not found + storageKeeper.EXPECT().GetBucketInfoById(gomock.Any(), gomock.Any()).Return(nil, false) + res := app.ExecuteSynPackage(s.ctx, nil, serializedSynPackage) + s.Require().ErrorIs(res.Err, types.ErrNoSuchBucket) + + // case 2: delete bucket error + storageKeeper.EXPECT().GetBucketInfoById(gomock.Any(), gomock.Any()).Return(&types.BucketInfo{ + BucketName: "bucket", + }, true) + storageKeeper.EXPECT().DeleteBucket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(fmt.Errorf("delete error")) + res = app.ExecuteSynPackage(s.ctx, nil, serializedSynPackage) + s.Require().ErrorContains(res.Err, "delete error") + + // case 3: delete bucket success + storageKeeper.EXPECT().GetBucketInfoById(gomock.Any(), gomock.Any()).Return(&types.BucketInfo{ + BucketName: "bucket", + Id: sdk.NewUint(10), + }, true) + storageKeeper.EXPECT().DeleteBucket(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + res = app.ExecuteSynPackage(s.ctx, nil, serializedSynPackage) + s.Require().NoError(res.Err) +} diff --git a/x/storage/keeper/cross_app_group.go b/x/storage/keeper/cross_app_group.go index 0f09fcceb..77b8e27a0 100644 --- a/x/storage/keeper/cross_app_group.go +++ b/x/storage/keeper/cross_app_group.go @@ -12,10 +12,10 @@ import ( var _ sdk.CrossChainApplication = &GroupApp{} type GroupApp struct { - storageKeeper Keeper + storageKeeper types.StorageKeeper } -func NewGroupApp(keeper Keeper) *GroupApp { +func NewGroupApp(keeper types.StorageKeeper) *GroupApp { return &GroupApp{ storageKeeper: keeper, } @@ -78,7 +78,7 @@ func (app *GroupApp) ExecuteFailAckPackage(ctx sdk.Context, appCtx *sdk.CrossCha operationType = types.OperationDeleteGroup result = app.handleDeleteGroupFailAckPackage(ctx, appCtx, p) case *types.UpdateGroupMemberSynPackage: - operationType = types.OperationDeleteGroup + operationType = types.OperationUpdateGroupMember result = app.handleUpdateGroupMemberFailAckPackage(ctx, appCtx, p) default: panic("unknown cross chain ack package type") @@ -174,7 +174,7 @@ func (app *GroupApp) handleDeleteGroupSynPackage(ctx sdk.Context, header *sdk.Cr err = app.storageKeeper.DeleteGroup(ctx, deleteGroupPackage.Operator, groupInfo.GroupName, - DeleteGroupOptions{ + types.DeleteGroupOptions{ SourceType: types.SOURCE_TYPE_BSC_CROSS_CHAIN, }, ) @@ -226,7 +226,7 @@ func (app *GroupApp) handleCreateGroupSynPackage(ctx sdk.Context, header *sdk.Cr groupId, err := app.storageKeeper.CreateGroup(ctx, createGroupPackage.Creator, createGroupPackage.GroupName, - CreateGroupOptions{ + types.CreateGroupOptions{ SourceType: types.SOURCE_TYPE_BSC_CROSS_CHAIN, }, ) @@ -336,7 +336,7 @@ func (app *GroupApp) handleUpdateGroupMemberSynPackage(ctx sdk.Context, header * } } - options := UpdateGroupMemberOptions{ + options := types.UpdateGroupMemberOptions{ SourceType: types.SOURCE_TYPE_BSC_CROSS_CHAIN, } if updateGroupPackage.OperationType == types.OperationAddGroupMember { diff --git a/x/storage/keeper/cross_app_object.go b/x/storage/keeper/cross_app_object.go index 43b540e34..9acc60fe7 100644 --- a/x/storage/keeper/cross_app_object.go +++ b/x/storage/keeper/cross_app_object.go @@ -12,10 +12,10 @@ import ( var _ sdk.CrossChainApplication = &ObjectApp{} type ObjectApp struct { - storageKeeper Keeper + storageKeeper types.StorageKeeper } -func NewObjectApp(keeper Keeper) *ObjectApp { +func NewObjectApp(keeper types.StorageKeeper) *ObjectApp { return &ObjectApp{ storageKeeper: keeper, } @@ -211,7 +211,7 @@ func (app *ObjectApp) handleDeleteObjectSynPackage(ctx sdk.Context, header *sdk. deleteObjectPackage.Operator, objectInfo.BucketName, objectInfo.ObjectName, - DeleteObjectOptions{ + types.DeleteObjectOptions{ SourceType: types.SOURCE_TYPE_BSC_CROSS_CHAIN, }, ) diff --git a/x/storage/keeper/keeper.go b/x/storage/keeper/keeper.go index 438d3f6d7..b6a350e03 100644 --- a/x/storage/keeper/keeper.go +++ b/x/storage/keeper/keeper.go @@ -86,7 +86,7 @@ func (k Keeper) Logger(ctx sdk.Context) log.Logger { func (k Keeper) CreateBucket( ctx sdk.Context, ownerAcc sdk.AccAddress, bucketName string, - primarySpAcc sdk.AccAddress, opts *CreateBucketOptions) (sdkmath.Uint, error) { + primarySpAcc sdk.AccAddress, opts *types.CreateBucketOptions) (sdkmath.Uint, error) { store := ctx.KVStore(k.storeKey) // check if the bucket exist @@ -170,7 +170,17 @@ func (k Keeper) CreateBucket( return bucketInfo.Id, nil } -func (k Keeper) DeleteBucket(ctx sdk.Context, operator sdk.AccAddress, bucketName string, opts DeleteBucketOptions) error { +// StoreBucketInfo will store the bucket info +// It's designed to be used by the test cases to create a bucket. +func (k Keeper) StoreBucketInfo(ctx sdk.Context, bucketInfo *types.BucketInfo) { + store := ctx.KVStore(k.storeKey) + bucketKey := types.GetBucketKey(bucketInfo.BucketName) + bz := k.cdc.MustMarshal(bucketInfo) + store.Set(bucketKey, k.bucketSeq.EncodeSequence(bucketInfo.Id)) + store.Set(types.GetBucketByIDKey(bucketInfo.Id), bz) +} + +func (k Keeper) DeleteBucket(ctx sdk.Context, operator sdk.AccAddress, bucketName string, opts types.DeleteBucketOptions) error { bucketInfo, found := k.GetBucketInfo(ctx, bucketName) if !found { return types.ErrNoSuchBucket @@ -334,7 +344,7 @@ func (k Keeper) ForceDeleteBucket(ctx sdk.Context, bucketId sdkmath.Uint, cap ui return bucketDeleted, deleted, nil } -func (k Keeper) UpdateBucketInfo(ctx sdk.Context, operator sdk.AccAddress, bucketName string, opts UpdateBucketOptions) error { +func (k Keeper) UpdateBucketInfo(ctx sdk.Context, operator sdk.AccAddress, bucketName string, opts types.UpdateBucketOptions) error { bucketInfo, found := k.GetBucketInfo(ctx, bucketName) if !found { return types.ErrNoSuchBucket @@ -499,7 +509,7 @@ func (k Keeper) GetBucketInfoById(ctx sdk.Context, bucketId sdkmath.Uint) (*type func (k Keeper) CreateObject( ctx sdk.Context, operator sdk.AccAddress, bucketName, objectName string, - payloadSize uint64, opts CreateObjectOptions) (sdkmath.Uint, error) { + payloadSize uint64, opts types.CreateObjectOptions) (sdkmath.Uint, error) { store := ctx.KVStore(k.storeKey) // check payload size @@ -619,6 +629,29 @@ func (k Keeper) CreateObject( return objectInfo.Id, nil } +// StoreObjectInfo stores object related keys to KVStore, +// it's designed to be used in tests +func (k Keeper) StoreObjectInfo(ctx sdk.Context, objectInfo *types.ObjectInfo) { + store := ctx.KVStore(k.storeKey) + + objectKey := types.GetObjectKey(objectInfo.BucketName, objectInfo.ObjectName) + + obz := k.cdc.MustMarshal(objectInfo) + store.Set(objectKey, k.objectSeq.EncodeSequence(objectInfo.Id)) + store.Set(types.GetObjectByIDKey(objectInfo.Id), obz) +} + +// DeleteObjectInfo deletes object related keys from KVStore, +// it's designed to be used in tests +func (k Keeper) DeleteObjectInfo(ctx sdk.Context, objectInfo *types.ObjectInfo) { + store := ctx.KVStore(k.storeKey) + + objectKey := types.GetObjectKey(objectInfo.BucketName, objectInfo.ObjectName) + + store.Delete(objectKey) + store.Delete(types.GetObjectByIDKey(objectInfo.Id)) +} + func (k Keeper) SetObjectInfo(ctx sdk.Context, objectInfo *types.ObjectInfo) { store := ctx.KVStore(k.storeKey) @@ -742,7 +775,7 @@ func (k Keeper) SealObject( func (k Keeper) CancelCreateObject( ctx sdk.Context, operator sdk.AccAddress, - bucketName, objectName string, opts CancelCreateObjectOptions) error { + bucketName, objectName string, opts types.CancelCreateObjectOptions) error { store := ctx.KVStore(k.storeKey) bucketInfo, found := k.GetBucketInfo(ctx, bucketName) if !found { @@ -796,7 +829,7 @@ func (k Keeper) CancelCreateObject( } func (k Keeper) DeleteObject( - ctx sdk.Context, operator sdk.AccAddress, bucketName, objectName string, opts DeleteObjectOptions) error { + ctx sdk.Context, operator sdk.AccAddress, bucketName, objectName string, opts types.DeleteObjectOptions) error { bucketInfo, found := k.GetBucketInfo(ctx, bucketName) if !found { @@ -917,7 +950,7 @@ func (k Keeper) ForceDeleteObject(ctx sdk.Context, objectId sdkmath.Uint) error func (k Keeper) CopyObject( ctx sdk.Context, operator sdk.AccAddress, srcBucketName, srcObjectName, dstBucketName, dstObjectName string, - opts CopyObjectOptions) (sdkmath.Uint, error) { + opts types.CopyObjectOptions) (sdkmath.Uint, error) { store := ctx.KVStore(k.storeKey) @@ -1178,7 +1211,7 @@ func (k Keeper) UpdateObjectInfo(ctx sdk.Context, operator sdk.AccAddress, bucke func (k Keeper) CreateGroup( ctx sdk.Context, owner sdk.AccAddress, - groupName string, opts CreateGroupOptions) (sdkmath.Uint, error) { + groupName string, opts types.CreateGroupOptions) (sdkmath.Uint, error) { store := ctx.KVStore(k.storeKey) groupInfo := types.GroupInfo{ @@ -1255,11 +1288,7 @@ func (k Keeper) GetGroupInfoById(ctx sdk.Context, groupId sdkmath.Uint) (*types. return &groupInfo, true } -type DeleteGroupOptions struct { - SourceType types.SourceType -} - -func (k Keeper) DeleteGroup(ctx sdk.Context, operator sdk.AccAddress, groupName string, opts DeleteGroupOptions) error { +func (k Keeper) DeleteGroup(ctx sdk.Context, operator sdk.AccAddress, groupName string, opts types.DeleteGroupOptions) error { store := ctx.KVStore(k.storeKey) groupInfo, found := k.GetGroupInfo(ctx, operator, groupName) @@ -1296,7 +1325,7 @@ func (k Keeper) DeleteGroup(ctx sdk.Context, operator sdk.AccAddress, groupName func (k Keeper) LeaveGroup( ctx sdk.Context, member sdk.AccAddress, owner sdk.AccAddress, - groupName string, opts LeaveGroupOptions) error { + groupName string, opts types.LeaveGroupOptions) error { groupInfo, found := k.GetGroupInfo(ctx, owner, groupName) if !found { @@ -1322,7 +1351,7 @@ func (k Keeper) LeaveGroup( return nil } -func (k Keeper) UpdateGroupMember(ctx sdk.Context, operator sdk.AccAddress, groupInfo *types.GroupInfo, opts UpdateGroupMemberOptions) error { +func (k Keeper) UpdateGroupMember(ctx sdk.Context, operator sdk.AccAddress, groupInfo *types.GroupInfo, opts types.UpdateGroupMemberOptions) error { if groupInfo.SourceType != opts.SourceType { return types.ErrSourceTypeMismatch } diff --git a/x/storage/keeper/keeper_object_test.go b/x/storage/keeper/keeper_object_test.go new file mode 100644 index 000000000..ba79025ba --- /dev/null +++ b/x/storage/keeper/keeper_object_test.go @@ -0,0 +1,220 @@ +package keeper_test + +import ( + "cosmossdk.io/math" + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/golang/mock/gomock" + + "github.com/bnb-chain/greenfield/testutil/sample" + "github.com/bnb-chain/greenfield/types/common" + types4 "github.com/bnb-chain/greenfield/x/payment/types" + types3 "github.com/bnb-chain/greenfield/x/sp/types" + "github.com/bnb-chain/greenfield/x/storage/types" + types2 "github.com/bnb-chain/greenfield/x/virtualgroup/types" +) + +func (s *TestSuite) TestCreateObject() { + operatorAddress := sample.RandAccAddress() + objectName := "objectName" + + bucketInfo := &types.BucketInfo{ + Owner: operatorAddress.String(), + BucketName: "bucketname", + Id: sdk.NewUint(1), + PaymentAddress: sample.RandAccAddress().String(), + ChargedReadQuota: 100, + BucketStatus: types.BUCKET_STATUS_CREATED, + } + + // case 1: bucket does not exist + _, err := s.storageKeeper.CreateObject(s.ctx, operatorAddress, bucketInfo.BucketName, + objectName, 100, types.CreateObjectOptions{ + Visibility: 0, + ContentType: "", + SourceType: 0, + RedundancyType: 0, + Checksums: nil, + PrimarySpApproval: nil, + ApprovalMsgBytes: nil, + }) + s.Require().ErrorContains(err, "No such bucket") + + // case 2: bucket is migrating + bucketInfo.BucketStatus = types.BUCKET_STATUS_MIGRATING + s.storageKeeper.StoreBucketInfo(s.ctx, bucketInfo) + _, err = s.storageKeeper.CreateObject(s.ctx, operatorAddress, bucketInfo.BucketName, + objectName, 100, types.CreateObjectOptions{ + Visibility: 0, + ContentType: "", + SourceType: 0, + RedundancyType: 0, + Checksums: nil, + PrimarySpApproval: nil, + ApprovalMsgBytes: nil, + }) + s.Require().ErrorContains(err, "the bucket is migrating") + + // case 3: bucket is discontinued + bucketInfo.BucketStatus = types.BUCKET_STATUS_DISCONTINUED + s.storageKeeper.StoreBucketInfo(s.ctx, bucketInfo) + _, err = s.storageKeeper.CreateObject(s.ctx, operatorAddress, bucketInfo.BucketName, + objectName, 100, types.CreateObjectOptions{ + Visibility: 0, + ContentType: "", + SourceType: 0, + RedundancyType: 0, + Checksums: nil, + PrimarySpApproval: nil, + ApprovalMsgBytes: nil, + }) + s.Require().ErrorContains(err, "the bucket is discontinued") + + // case 4: invalid payload size + _, err = s.storageKeeper.CreateObject(s.ctx, operatorAddress, bucketInfo.BucketName, + objectName, types.DefaultParams().MaxPayloadSize+1, types.CreateObjectOptions{ + Visibility: 0, + ContentType: "", + SourceType: 0, + RedundancyType: 0, + Checksums: nil, + PrimarySpApproval: nil, + ApprovalMsgBytes: nil, + }) + s.Require().ErrorContains(err, "Object payload size is too large") + + // case 4: gvg family does not exist + bucketInfo.BucketStatus = types.BUCKET_STATUS_CREATED + s.storageKeeper.StoreBucketInfo(s.ctx, bucketInfo) + s.virtualGroupKeeper.EXPECT().GetGVGFamily(gomock.Any(), gomock.Any()).Return(nil, false) + s.Require().Panics(func() { + _, _ = s.storageKeeper.CreateObject(s.ctx, operatorAddress, bucketInfo.BucketName, + objectName, 100, types.CreateObjectOptions{ + Visibility: 0, + ContentType: "", + SourceType: 0, + RedundancyType: 0, + Checksums: nil, + PrimarySpApproval: nil, + ApprovalMsgBytes: nil, + }) + }) + + // case 5: approval expired + s.virtualGroupKeeper.EXPECT().GetGVGFamily(gomock.Any(), gomock.Any()).Return(&types2.GlobalVirtualGroupFamily{ + Id: 0, + PrimarySpId: 0, + GlobalVirtualGroupIds: nil, + VirtualPaymentAddress: "", + }, true).AnyTimes() + + spAddress, signBytes, sig := sample.RandSignBytes() + s.spKeeper.EXPECT().MustGetStorageProvider(gomock.Any(), gomock.Any()).Return(&types3.StorageProvider{ + Id: 0, + OperatorAddress: spAddress.String(), + FundingAddress: "", + SealAddress: "", + ApprovalAddress: spAddress.String(), + GcAddress: "", + TotalDeposit: math.Int{}, + Status: 0, + Endpoint: "", + Description: types3.Description{}, + BlsKey: nil, + }).AnyTimes() + s.ctx = s.ctx.WithBlockHeight(100) + _, err = s.storageKeeper.CreateObject(s.ctx, operatorAddress, bucketInfo.BucketName, + objectName, 100, types.CreateObjectOptions{ + Visibility: 0, + ContentType: "", + SourceType: 0, + RedundancyType: 0, + Checksums: nil, + PrimarySpApproval: &common.Approval{ + ExpiredHeight: uint64(s.ctx.BlockHeight() - 1), + Sig: sig, + }, + ApprovalMsgBytes: signBytes, + }) + + s.Require().ErrorContains(err, "The approval of sp is expired") + + // case 6: invalid approval sig + _, err = s.storageKeeper.CreateObject(s.ctx, operatorAddress, bucketInfo.BucketName, + objectName, 100, types.CreateObjectOptions{ + Visibility: 0, + ContentType: "", + SourceType: 0, + RedundancyType: 0, + Checksums: nil, + PrimarySpApproval: &common.Approval{ + ExpiredHeight: uint64(s.ctx.BlockHeight() + 1), + Sig: []byte("invalid sig"), + }, + ApprovalMsgBytes: signBytes, + }) + s.Require().ErrorContains(err, "verify signature error") + + // case 7: object exist + s.storageKeeper.StoreObjectInfo(s.ctx, &types.ObjectInfo{ + Id: sdk.NewUint(1), + BucketName: bucketInfo.BucketName, + ObjectName: objectName, + }) + _, err = s.storageKeeper.CreateObject(s.ctx, operatorAddress, bucketInfo.BucketName, + objectName, 100, types.CreateObjectOptions{ + Visibility: 0, + ContentType: "", + SourceType: 0, + RedundancyType: 0, + Checksums: nil, + PrimarySpApproval: &common.Approval{ + ExpiredHeight: uint64(s.ctx.BlockHeight() + 1), + Sig: sig, + }, + ApprovalMsgBytes: signBytes, + }) + s.Require().ErrorContains(err, "Object already exists") + + // case 8: valid case + s.storageKeeper.DeleteObjectInfo(s.ctx, &types.ObjectInfo{ + Id: sdk.NewUint(1), + BucketName: bucketInfo.BucketName, + ObjectName: objectName, + }) + s.paymentKeeper.EXPECT().GetStoragePrice(gomock.Any(), gomock.Any()).Return(types4.StoragePrice{ + ReadPrice: sdk.NewDec(1), + PrimaryStorePrice: sdk.NewDec(2), + SecondaryStorePrice: sdk.NewDec(1), + }, nil).AnyTimes() + s.paymentKeeper.EXPECT().GetVersionedParamsWithTs(gomock.Any(), gomock.Any()).Return(types4.VersionedParams{ + ReserveTime: 10000, + ValidatorTaxRate: sdk.NewDec(1), + }, nil).AnyTimes() + s.paymentKeeper.EXPECT().UpdateStreamRecordByAddr(gomock.Any(), gomock.Any()).Return(&types4.StreamRecord{ + Account: "", + CrudTimestamp: 0, + NetflowRate: math.Int{}, + StaticBalance: sdk.NewInt(100), + BufferBalance: math.Int{}, + LockBalance: math.Int{}, + Status: 0, + SettleTimestamp: 0, + OutFlowCount: 0, + FrozenNetflowRate: math.Int{}, + }, nil).AnyTimes() + _, err = s.storageKeeper.CreateObject(s.ctx, operatorAddress, bucketInfo.BucketName, + objectName, 100, types.CreateObjectOptions{ + Visibility: 0, + ContentType: "", + SourceType: 0, + RedundancyType: 0, + Checksums: nil, + PrimarySpApproval: &common.Approval{ + ExpiredHeight: uint64(s.ctx.BlockHeight() + 1), + Sig: sig, + }, + ApprovalMsgBytes: signBytes, + }) + + s.Require().NoError(err) +} diff --git a/x/storage/keeper/msg_server.go b/x/storage/keeper/msg_server.go index f2da636d3..cd7277471 100644 --- a/x/storage/keeper/msg_server.go +++ b/x/storage/keeper/msg_server.go @@ -36,7 +36,7 @@ func (k msgServer) CreateBucket(goCtx context.Context, msg *types.MsgCreateBucke primarySPAcc := sdk.MustAccAddressFromHex(msg.PrimarySpAddress) - id, err := k.Keeper.CreateBucket(ctx, ownerAcc, msg.BucketName, primarySPAcc, &CreateBucketOptions{ + id, err := k.Keeper.CreateBucket(ctx, ownerAcc, msg.BucketName, primarySPAcc, &storagetypes.CreateBucketOptions{ PaymentAddress: msg.PaymentAddress, Visibility: msg.Visibility, ChargedReadQuota: msg.ChargedReadQuota, @@ -58,7 +58,7 @@ func (k msgServer) DeleteBucket(goCtx context.Context, msg *types.MsgDeleteBucke operatorAcc := sdk.MustAccAddressFromHex(msg.Operator) - err := k.Keeper.DeleteBucket(ctx, operatorAcc, msg.BucketName, DeleteBucketOptions{ + err := k.Keeper.DeleteBucket(ctx, operatorAcc, msg.BucketName, storagetypes.DeleteBucketOptions{ SourceType: types.SOURCE_TYPE_ORIGIN, }) if err != nil { @@ -76,7 +76,7 @@ func (k msgServer) UpdateBucketInfo(goCtx context.Context, msg *types.MsgUpdateB if msg.ChargedReadQuota != nil { chargedReadQuota = &msg.ChargedReadQuota.Value } - err := k.Keeper.UpdateBucketInfo(ctx, operatorAcc, msg.BucketName, UpdateBucketOptions{ + err := k.Keeper.UpdateBucketInfo(ctx, operatorAcc, msg.BucketName, storagetypes.UpdateBucketOptions{ SourceType: types.SOURCE_TYPE_ORIGIN, PaymentAddress: msg.PaymentAddress, Visibility: msg.Visibility, @@ -111,7 +111,7 @@ func (k msgServer) CreateObject(goCtx context.Context, msg *types.MsgCreateObjec len(msg.ExpectChecksums)) } - id, err := k.Keeper.CreateObject(ctx, ownerAcc, msg.BucketName, msg.ObjectName, msg.PayloadSize, CreateObjectOptions{ + id, err := k.Keeper.CreateObject(ctx, ownerAcc, msg.BucketName, msg.ObjectName, msg.PayloadSize, storagetypes.CreateObjectOptions{ SourceType: types.SOURCE_TYPE_ORIGIN, Visibility: msg.Visibility, ContentType: msg.ContentType, @@ -134,7 +134,7 @@ func (k msgServer) CancelCreateObject(goCtx context.Context, msg *types.MsgCance operatorAcc := sdk.MustAccAddressFromHex(msg.Operator) - err := k.Keeper.CancelCreateObject(ctx, operatorAcc, msg.BucketName, msg.ObjectName, CancelCreateObjectOptions{SourceType: types.SOURCE_TYPE_ORIGIN}) + err := k.Keeper.CancelCreateObject(ctx, operatorAcc, msg.BucketName, msg.ObjectName, storagetypes.CancelCreateObjectOptions{SourceType: types.SOURCE_TYPE_ORIGIN}) if err != nil { return nil, err } @@ -164,7 +164,7 @@ func (k msgServer) CopyObject(goCtx context.Context, msg *types.MsgCopyObject) ( ownerAcc := sdk.MustAccAddressFromHex(msg.Operator) - id, err := k.Keeper.CopyObject(ctx, ownerAcc, msg.SrcBucketName, msg.SrcObjectName, msg.DstBucketName, msg.DstObjectName, CopyObjectOptions{ + id, err := k.Keeper.CopyObject(ctx, ownerAcc, msg.SrcBucketName, msg.SrcObjectName, msg.DstBucketName, msg.DstObjectName, storagetypes.CopyObjectOptions{ SourceType: types.SOURCE_TYPE_ORIGIN, Visibility: storagetypes.VISIBILITY_TYPE_PRIVATE, PrimarySpApproval: msg.DstPrimarySpApproval, @@ -184,7 +184,7 @@ func (k msgServer) DeleteObject(goCtx context.Context, msg *types.MsgDeleteObjec operatorAcc := sdk.MustAccAddressFromHex(msg.Operator) - err := k.Keeper.DeleteObject(ctx, operatorAcc, msg.BucketName, msg.ObjectName, DeleteObjectOptions{ + err := k.Keeper.DeleteObject(ctx, operatorAcc, msg.BucketName, msg.ObjectName, storagetypes.DeleteObjectOptions{ SourceType: types.SOURCE_TYPE_ORIGIN, }) @@ -232,7 +232,7 @@ func (k msgServer) CreateGroup(goCtx context.Context, msg *types.MsgCreateGroup) ownerAcc := sdk.MustAccAddressFromHex(msg.Creator) - id, err := k.Keeper.CreateGroup(ctx, ownerAcc, msg.GroupName, CreateGroupOptions{Members: msg.Members, Extra: msg.Extra}) + id, err := k.Keeper.CreateGroup(ctx, ownerAcc, msg.GroupName, storagetypes.CreateGroupOptions{Members: msg.Members, Extra: msg.Extra}) if err != nil { return nil, err } @@ -246,7 +246,7 @@ func (k msgServer) DeleteGroup(goCtx context.Context, msg *types.MsgDeleteGroup) ctx := sdk.UnwrapSDKContext(goCtx) operatorAcc := sdk.MustAccAddressFromHex(msg.Operator) - err := k.Keeper.DeleteGroup(ctx, operatorAcc, msg.GroupName, DeleteGroupOptions{SourceType: types.SOURCE_TYPE_ORIGIN}) + err := k.Keeper.DeleteGroup(ctx, operatorAcc, msg.GroupName, storagetypes.DeleteGroupOptions{SourceType: types.SOURCE_TYPE_ORIGIN}) if err != nil { return nil, err } @@ -261,7 +261,7 @@ func (k msgServer) LeaveGroup(goCtx context.Context, msg *types.MsgLeaveGroup) ( ownerAcc := sdk.MustAccAddressFromHex(msg.GroupOwner) - err := k.Keeper.LeaveGroup(ctx, memberAcc, ownerAcc, msg.GroupName, LeaveGroupOptions{SourceType: types.SOURCE_TYPE_ORIGIN}) + err := k.Keeper.LeaveGroup(ctx, memberAcc, ownerAcc, msg.GroupName, storagetypes.LeaveGroupOptions{SourceType: types.SOURCE_TYPE_ORIGIN}) if err != nil { return nil, err } @@ -280,7 +280,7 @@ func (k msgServer) UpdateGroupMember(goCtx context.Context, msg *types.MsgUpdate if !found { return nil, types.ErrNoSuchGroup } - err := k.Keeper.UpdateGroupMember(ctx, operator, groupInfo, UpdateGroupMemberOptions{ + err := k.Keeper.UpdateGroupMember(ctx, operator, groupInfo, storagetypes.UpdateGroupMemberOptions{ SourceType: types.SOURCE_TYPE_ORIGIN, MembersToAdd: msg.MembersToAdd, MembersToDelete: msg.MembersToDelete, diff --git a/x/storage/module_simulation.go b/x/storage/module_simulation.go index 2262f412b..6d58c4a58 100644 --- a/x/storage/module_simulation.go +++ b/x/storage/module_simulation.go @@ -14,7 +14,7 @@ import ( // avoid unused import issue var ( - _ = sample.AccAddress + _ = sample.RandAccAddressHex _ = storagesimulation.FindAccount _ = simulation.MsgEntryKind _ = baseapp.Paramspace diff --git a/x/storage/types/crosschain.go b/x/storage/types/crosschain.go index a09c1ae2c..36938efd8 100644 --- a/x/storage/types/crosschain.go +++ b/x/storage/types/crosschain.go @@ -525,6 +525,18 @@ var ( } ) +func (p DeleteBucketSynPackage) MustSerialize() []byte { + encodedBytes, err := generalDeleteSynPackageArgs.Pack(&GeneralDeleteSynPackageStruct{ + Operator: common.BytesToAddress(p.Operator), + Id: SafeBigInt(p.Id), + ExtraData: p.ExtraData, + }) + if err != nil { + panic("encode delete bucket sync package error") + } + return encodedBytes +} + func (p DeleteBucketSynPackage) ValidateBasic() error { if p.Operator.Empty() { return sdkerrors.ErrInvalidAddress @@ -574,7 +586,7 @@ var ( ) func (p DeleteBucketAckPackage) MustSerialize() []byte { - encodedBytes, err := generalCreateAckPackageArgs.Pack(&DeleteBucketAckPackage{ + encodedBytes, err := generalDeleteAckPackageArgs.Pack(&DeleteBucketAckPackage{ p.Status, SafeBigInt(p.Id), p.ExtraData, diff --git a/x/storage/types/errors.go b/x/storage/types/errors.go index efe59093c..9fb387674 100644 --- a/x/storage/types/errors.go +++ b/x/storage/types/errors.go @@ -2,8 +2,6 @@ package types import ( "cosmossdk.io/errors" - - "github.com/bnb-chain/greenfield/x/virtualgroup/types" ) // x/storage module sentinel errors @@ -49,5 +47,5 @@ var ( ErrInvalidResource = errors.Register(ModuleName, 3201, "invalid resource type") ErrMigrationBucketFailed = errors.Register(ModuleName, 3202, "migrate bucket failed.") ErrVirtualGroupOperateFailed = errors.Register(ModuleName, 3203, "operate virtual group failed.") - ErrInvalidBlsPubKey = errors.Register(types.ModuleName, 1122, "invalid bls public key") + ErrInvalidBlsPubKey = errors.Register(ModuleName, 3204, "invalid bls public key") ) diff --git a/x/storage/types/expected_keepers.go b/x/storage/types/expected_keepers.go index 5bc76a07f..85b43af1f 100644 --- a/x/storage/types/expected_keepers.go +++ b/x/storage/types/expected_keepers.go @@ -4,6 +4,8 @@ import ( "math/big" "cosmossdk.io/math" + sdkmath "cosmossdk.io/math" + "github.com/cometbft/cometbft/libs/log" sdk "github.com/cosmos/cosmos-sdk/types" authtypes "github.com/cosmos/cosmos-sdk/x/auth/types" @@ -90,3 +92,25 @@ type VirtualGroupKeeper interface { GetAndCheckGVGFamilyAvailableForNewBucket(ctx sdk.Context, familyID uint32) (*types.GlobalVirtualGroupFamily, error) GetGlobalVirtualGroupIfAvailable(ctx sdk.Context, gvgID uint32, expectedStoreSize uint64) (*types.GlobalVirtualGroup, error) } + +// StorageKeeper used by the cross-chain applications +type StorageKeeper interface { + Logger(ctx sdk.Context) log.Logger + GetBucketInfoById(ctx sdk.Context, bucketId sdkmath.Uint) (*BucketInfo, bool) + SetBucketInfo(ctx sdk.Context, bucketInfo *BucketInfo) + CreateBucket( + ctx sdk.Context, ownerAcc sdk.AccAddress, bucketName string, + primarySpAcc sdk.AccAddress, opts *CreateBucketOptions) (sdkmath.Uint, error) + DeleteBucket(ctx sdk.Context, operator sdk.AccAddress, bucketName string, opts DeleteBucketOptions) error + GetGroupInfoById(ctx sdk.Context, groupId sdkmath.Uint) (*GroupInfo, bool) + DeleteGroup(ctx sdk.Context, operator sdk.AccAddress, groupName string, opts DeleteGroupOptions) error + CreateGroup( + ctx sdk.Context, owner sdk.AccAddress, + groupName string, opts CreateGroupOptions) (sdkmath.Uint, error) + SetGroupInfo(ctx sdk.Context, groupInfo *GroupInfo) + UpdateGroupMember(ctx sdk.Context, operator sdk.AccAddress, groupInfo *GroupInfo, opts UpdateGroupMemberOptions) error + GetObjectInfoById(ctx sdk.Context, objectId sdkmath.Uint) (*ObjectInfo, bool) + SetObjectInfo(ctx sdk.Context, objectInfo *ObjectInfo) + DeleteObject( + ctx sdk.Context, operator sdk.AccAddress, bucketName, objectName string, opts DeleteObjectOptions) error +} diff --git a/x/storage/types/expected_keepers_mocks.go b/x/storage/types/expected_keepers_mocks.go index 64ce48f30..12d9f9a5f 100644 --- a/x/storage/types/expected_keepers_mocks.go +++ b/x/storage/types/expected_keepers_mocks.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: x/storage/types/expected_keepers.go +// Source: expected_keepers.go // Package types is a generated GoMock package. package types @@ -9,14 +9,16 @@ import ( reflect "reflect" math "cosmossdk.io/math" + log "github.com/cometbft/cometbft/libs/log" + types3 "github.com/cosmos/cosmos-sdk/types" + types4 "github.com/cosmos/cosmos-sdk/x/auth/types" + gomock "github.com/golang/mock/gomock" + resource "github.com/bnb-chain/greenfield/types/resource" types "github.com/bnb-chain/greenfield/x/payment/types" types0 "github.com/bnb-chain/greenfield/x/permission/types" types1 "github.com/bnb-chain/greenfield/x/sp/types" types2 "github.com/bnb-chain/greenfield/x/virtualgroup/types" - types3 "github.com/cosmos/cosmos-sdk/types" - types4 "github.com/cosmos/cosmos-sdk/x/auth/types" - gomock "github.com/golang/mock/gomock" ) // MockAccountKeeper is a mock of AccountKeeper interface. @@ -846,3 +848,207 @@ func (mr *MockVirtualGroupKeeperMockRecorder) SettleAndDistributeGVGFamily(ctx, mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SettleAndDistributeGVGFamily", reflect.TypeOf((*MockVirtualGroupKeeper)(nil).SettleAndDistributeGVGFamily), ctx, sp, family) } + +// MockStorageKeeper is a mock of StorageKeeper interface. +type MockStorageKeeper struct { + ctrl *gomock.Controller + recorder *MockStorageKeeperMockRecorder +} + +// MockStorageKeeperMockRecorder is the mock recorder for MockStorageKeeper. +type MockStorageKeeperMockRecorder struct { + mock *MockStorageKeeper +} + +// NewMockStorageKeeper creates a new mock instance. +func NewMockStorageKeeper(ctrl *gomock.Controller) *MockStorageKeeper { + mock := &MockStorageKeeper{ctrl: ctrl} + mock.recorder = &MockStorageKeeperMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockStorageKeeper) EXPECT() *MockStorageKeeperMockRecorder { + return m.recorder +} + +// CreateBucket mocks base method. +func (m *MockStorageKeeper) CreateBucket(ctx types3.Context, ownerAcc types3.AccAddress, bucketName string, primarySpAcc types3.AccAddress, opts *CreateBucketOptions) (math.Uint, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateBucket", ctx, ownerAcc, bucketName, primarySpAcc, opts) + ret0, _ := ret[0].(math.Uint) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateBucket indicates an expected call of CreateBucket. +func (mr *MockStorageKeeperMockRecorder) CreateBucket(ctx, ownerAcc, bucketName, primarySpAcc, opts interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateBucket", reflect.TypeOf((*MockStorageKeeper)(nil).CreateBucket), ctx, ownerAcc, bucketName, primarySpAcc, opts) +} + +// CreateGroup mocks base method. +func (m *MockStorageKeeper) CreateGroup(ctx types3.Context, owner types3.AccAddress, groupName string, opts CreateGroupOptions) (math.Uint, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateGroup", ctx, owner, groupName, opts) + ret0, _ := ret[0].(math.Uint) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateGroup indicates an expected call of CreateGroup. +func (mr *MockStorageKeeperMockRecorder) CreateGroup(ctx, owner, groupName, opts interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateGroup", reflect.TypeOf((*MockStorageKeeper)(nil).CreateGroup), ctx, owner, groupName, opts) +} + +// DeleteBucket mocks base method. +func (m *MockStorageKeeper) DeleteBucket(ctx types3.Context, operator types3.AccAddress, bucketName string, opts DeleteBucketOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteBucket", ctx, operator, bucketName, opts) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteBucket indicates an expected call of DeleteBucket. +func (mr *MockStorageKeeperMockRecorder) DeleteBucket(ctx, operator, bucketName, opts interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteBucket", reflect.TypeOf((*MockStorageKeeper)(nil).DeleteBucket), ctx, operator, bucketName, opts) +} + +// DeleteGroup mocks base method. +func (m *MockStorageKeeper) DeleteGroup(ctx types3.Context, operator types3.AccAddress, groupName string, opts DeleteGroupOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteGroup", ctx, operator, groupName, opts) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteGroup indicates an expected call of DeleteGroup. +func (mr *MockStorageKeeperMockRecorder) DeleteGroup(ctx, operator, groupName, opts interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteGroup", reflect.TypeOf((*MockStorageKeeper)(nil).DeleteGroup), ctx, operator, groupName, opts) +} + +// DeleteObject mocks base method. +func (m *MockStorageKeeper) DeleteObject(ctx types3.Context, operator types3.AccAddress, bucketName, objectName string, opts DeleteObjectOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteObject", ctx, operator, bucketName, objectName, opts) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteObject indicates an expected call of DeleteObject. +func (mr *MockStorageKeeperMockRecorder) DeleteObject(ctx, operator, bucketName, objectName, opts interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteObject", reflect.TypeOf((*MockStorageKeeper)(nil).DeleteObject), ctx, operator, bucketName, objectName, opts) +} + +// GetBucketInfoById mocks base method. +func (m *MockStorageKeeper) GetBucketInfoById(ctx types3.Context, bucketId math.Uint) (*BucketInfo, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetBucketInfoById", ctx, bucketId) + ret0, _ := ret[0].(*BucketInfo) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetBucketInfoById indicates an expected call of GetBucketInfoById. +func (mr *MockStorageKeeperMockRecorder) GetBucketInfoById(ctx, bucketId interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetBucketInfoById", reflect.TypeOf((*MockStorageKeeper)(nil).GetBucketInfoById), ctx, bucketId) +} + +// GetGroupInfoById mocks base method. +func (m *MockStorageKeeper) GetGroupInfoById(ctx types3.Context, groupId math.Uint) (*GroupInfo, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetGroupInfoById", ctx, groupId) + ret0, _ := ret[0].(*GroupInfo) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetGroupInfoById indicates an expected call of GetGroupInfoById. +func (mr *MockStorageKeeperMockRecorder) GetGroupInfoById(ctx, groupId interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupInfoById", reflect.TypeOf((*MockStorageKeeper)(nil).GetGroupInfoById), ctx, groupId) +} + +// GetObjectInfoById mocks base method. +func (m *MockStorageKeeper) GetObjectInfoById(ctx types3.Context, objectId math.Uint) (*ObjectInfo, bool) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetObjectInfoById", ctx, objectId) + ret0, _ := ret[0].(*ObjectInfo) + ret1, _ := ret[1].(bool) + return ret0, ret1 +} + +// GetObjectInfoById indicates an expected call of GetObjectInfoById. +func (mr *MockStorageKeeperMockRecorder) GetObjectInfoById(ctx, objectId interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetObjectInfoById", reflect.TypeOf((*MockStorageKeeper)(nil).GetObjectInfoById), ctx, objectId) +} + +// Logger mocks base method. +func (m *MockStorageKeeper) Logger(ctx types3.Context) log.Logger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Logger", ctx) + ret0, _ := ret[0].(log.Logger) + return ret0 +} + +// Logger indicates an expected call of Logger. +func (mr *MockStorageKeeperMockRecorder) Logger(ctx interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Logger", reflect.TypeOf((*MockStorageKeeper)(nil).Logger), ctx) +} + +// SetBucketInfo mocks base method. +func (m *MockStorageKeeper) SetBucketInfo(ctx types3.Context, bucketInfo *BucketInfo) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetBucketInfo", ctx, bucketInfo) +} + +// SetBucketInfo indicates an expected call of SetBucketInfo. +func (mr *MockStorageKeeperMockRecorder) SetBucketInfo(ctx, bucketInfo interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetBucketInfo", reflect.TypeOf((*MockStorageKeeper)(nil).SetBucketInfo), ctx, bucketInfo) +} + +// SetGroupInfo mocks base method. +func (m *MockStorageKeeper) SetGroupInfo(ctx types3.Context, groupInfo *GroupInfo) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetGroupInfo", ctx, groupInfo) +} + +// SetGroupInfo indicates an expected call of SetGroupInfo. +func (mr *MockStorageKeeperMockRecorder) SetGroupInfo(ctx, groupInfo interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetGroupInfo", reflect.TypeOf((*MockStorageKeeper)(nil).SetGroupInfo), ctx, groupInfo) +} + +// SetObjectInfo mocks base method. +func (m *MockStorageKeeper) SetObjectInfo(ctx types3.Context, objectInfo *ObjectInfo) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetObjectInfo", ctx, objectInfo) +} + +// SetObjectInfo indicates an expected call of SetObjectInfo. +func (mr *MockStorageKeeperMockRecorder) SetObjectInfo(ctx, objectInfo interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetObjectInfo", reflect.TypeOf((*MockStorageKeeper)(nil).SetObjectInfo), ctx, objectInfo) +} + +// UpdateGroupMember mocks base method. +func (m *MockStorageKeeper) UpdateGroupMember(ctx types3.Context, operator types3.AccAddress, groupInfo *GroupInfo, opts UpdateGroupMemberOptions) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateGroupMember", ctx, operator, groupInfo, opts) + ret0, _ := ret[0].(error) + return ret0 +} + +// UpdateGroupMember indicates an expected call of UpdateGroupMember. +func (mr *MockStorageKeeperMockRecorder) UpdateGroupMember(ctx, operator, groupInfo, opts interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateGroupMember", reflect.TypeOf((*MockStorageKeeper)(nil).UpdateGroupMember), ctx, operator, groupInfo, opts) +} diff --git a/x/storage/types/message.go b/x/storage/types/message.go index e2b937fe6..b55a6e44a 100644 --- a/x/storage/types/message.go +++ b/x/storage/types/message.go @@ -523,7 +523,7 @@ func (msg *MsgSealObject) ValidateBasic() error { } if len(msg.GetSecondarySpBlsAggSignatures()) != sdk.BLSSignatureLength { - return errors.Wrap(sdkerrors.ErrInvalidRequest, + return errors.Wrap(gnfderrors.ErrInvalidBlsSignature, fmt.Sprintf("length of signature should be %d", sdk.BLSSignatureLength), ) } diff --git a/x/storage/types/message_cancel_migrate_bucket_test.go b/x/storage/types/message_cancel_migrate_bucket_test.go index 9e79b0995..86897c824 100644 --- a/x/storage/types/message_cancel_migrate_bucket_test.go +++ b/x/storage/types/message_cancel_migrate_bucket_test.go @@ -25,7 +25,7 @@ func TestMsgCancelMigrateBucket_ValidateBasic(t *testing.T) { }, { name: "valid address", msg: MsgCancelMigrateBucket{ - Operator: sample.AccAddress(), + Operator: sample.RandAccAddressHex(), BucketName: testBucketName, }, }, diff --git a/x/storage/types/message_complete_migrate_bucket_test.go b/x/storage/types/message_complete_migrate_bucket_test.go index 1ad3783d0..4a25973b9 100644 --- a/x/storage/types/message_complete_migrate_bucket_test.go +++ b/x/storage/types/message_complete_migrate_bucket_test.go @@ -25,7 +25,7 @@ func TestMsgCompleteMigrateBucket_ValidateBasic(t *testing.T) { }, { name: "valid address", msg: MsgCompleteMigrateBucket{ - Operator: sample.AccAddress(), + Operator: sample.RandAccAddressHex(), BucketName: "bucketname", GlobalVirtualGroupFamilyId: 1, GvgMappings: []*GVGMapping{{1, 2, []byte("xxxxxxxxxxx")}}, diff --git a/x/storage/types/message_migrate_bucket_test.go b/x/storage/types/message_migrate_bucket_test.go index 84d11d107..c0c4f7581 100644 --- a/x/storage/types/message_migrate_bucket_test.go +++ b/x/storage/types/message_migrate_bucket_test.go @@ -26,7 +26,7 @@ func TestMsgMigrateBucket_ValidateBasic(t *testing.T) { }, { name: "valid address", msg: MsgMigrateBucket{ - Operator: sample.AccAddress(), + Operator: sample.RandAccAddressHex(), BucketName: "bucketname", DstPrimarySpId: 1, DstPrimarySpApproval: &common.Approval{ExpiredHeight: 10, Sig: []byte("XXXTentacion")}, diff --git a/x/storage/types/message_object_test.go b/x/storage/types/message_object_test.go new file mode 100644 index 000000000..0a6d38a14 --- /dev/null +++ b/x/storage/types/message_object_test.go @@ -0,0 +1,567 @@ +package types + +import ( + "strings" + "testing" + + "cosmossdk.io/math" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + "github.com/prysmaticlabs/prysm/crypto/bls" + "github.com/stretchr/testify/require" + + "github.com/bnb-chain/greenfield/testutil/sample" + "github.com/bnb-chain/greenfield/types/common" + gnfderrors "github.com/bnb-chain/greenfield/types/errors" +) + +func TestMsgCreateObject_ValidateBasic(t *testing.T) { + tests := []struct { + name string + msg MsgCreateObject + err error + }{ + { + name: "normal", + msg: MsgCreateObject{ + Creator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectName: testObjectName, + PayloadSize: 1024, + Visibility: VISIBILITY_TYPE_PRIVATE, + ContentType: "content-type", + PrimarySpApproval: &common.Approval{}, + ExpectChecksums: [][]byte{sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum()}, + }, + }, { + name: "invalid object name", + msg: MsgCreateObject{ + Creator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectName: "", + PayloadSize: 1024, + Visibility: VISIBILITY_TYPE_PRIVATE, + ContentType: "content-type", + PrimarySpApproval: &common.Approval{}, + ExpectChecksums: [][]byte{sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum()}, + }, + err: gnfderrors.ErrInvalidObjectName, + }, { + name: "invalid object name", + msg: MsgCreateObject{ + Creator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectName: "../object", + PayloadSize: 1024, + Visibility: VISIBILITY_TYPE_PRIVATE, + ContentType: "content-type", + PrimarySpApproval: &common.Approval{}, + ExpectChecksums: [][]byte{sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum()}, + }, + err: gnfderrors.ErrInvalidObjectName, + }, { + name: "invalid object name", + msg: MsgCreateObject{ + Creator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectName: "//object", + PayloadSize: 1024, + Visibility: VISIBILITY_TYPE_PRIVATE, + ContentType: "content-type", + PrimarySpApproval: &common.Approval{}, + ExpectChecksums: [][]byte{sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum()}, + }, + err: gnfderrors.ErrInvalidObjectName, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.ValidateBasic() + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMsgCancelCreateObject_ValidateBasic(t *testing.T) { + tests := []struct { + name string + msg MsgCancelCreateObject + err error + }{ + { + name: "basic", + msg: MsgCancelCreateObject{ + Operator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectName: testObjectName, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.ValidateBasic() + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMsgDeleteObject_ValidateBasic(t *testing.T) { + tests := []struct { + name string + msg MsgDeleteObject + err error + }{ + { + name: "normal", + msg: MsgDeleteObject{ + Operator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectName: testObjectName, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.ValidateBasic() + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMsgCopyObject_ValidateBasic(t *testing.T) { + tests := []struct { + name string + msg MsgCopyObject + err error + }{ + { + name: "valid address", + msg: MsgCopyObject{ + Operator: sample.RandAccAddressHex(), + SrcBucketName: testBucketName, + SrcObjectName: testObjectName, + DstBucketName: "dst" + testBucketName, + DstObjectName: "dst" + testObjectName, + DstPrimarySpApproval: &common.Approval{ + ExpiredHeight: 100, + Sig: []byte("xxx"), + }, + }, + }, + { + name: "invalid address", + msg: MsgCopyObject{ + Operator: "invalid address", + SrcBucketName: testBucketName, + SrcObjectName: testObjectName, + DstBucketName: "dst" + testBucketName, + DstObjectName: "dst" + testObjectName, + DstPrimarySpApproval: &common.Approval{ + ExpiredHeight: 100, + Sig: []byte("xxx"), + }, + }, + err: sdkerrors.ErrInvalidAddress, + }, + { + name: "empty approval", + msg: MsgCopyObject{ + Operator: sample.RandAccAddressHex(), + SrcBucketName: testBucketName, + SrcObjectName: testObjectName, + DstBucketName: "dst" + testBucketName, + DstObjectName: "dst" + testObjectName, + DstPrimarySpApproval: nil, + }, + err: ErrInvalidApproval, + }, + { + name: "invalid src bucket name", + msg: MsgCopyObject{ + Operator: sample.RandAccAddressHex(), + SrcBucketName: "1.1.1.1", + SrcObjectName: testObjectName, + DstBucketName: "dst" + testBucketName, + DstObjectName: "dst" + testObjectName, + DstPrimarySpApproval: &common.Approval{ + ExpiredHeight: 100, + Sig: []byte("xxx"), + }, + }, + err: gnfderrors.ErrInvalidBucketName, + }, + { + name: "invalid src object name", + msg: MsgCopyObject{ + Operator: sample.RandAccAddressHex(), + SrcBucketName: testBucketName, + SrcObjectName: "", + DstBucketName: "dst" + testBucketName, + DstObjectName: "dst" + testObjectName, + DstPrimarySpApproval: &common.Approval{ + ExpiredHeight: 100, + Sig: []byte("xxx"), + }, + }, + err: gnfderrors.ErrInvalidObjectName, + }, + { + name: "invalid dest bucket name", + msg: MsgCopyObject{ + Operator: sample.RandAccAddressHex(), + SrcBucketName: testBucketName, + SrcObjectName: testObjectName, + DstBucketName: "1.1.1.1", + DstObjectName: "dst" + testObjectName, + DstPrimarySpApproval: &common.Approval{ + ExpiredHeight: 100, + Sig: []byte("xxx"), + }, + }, + err: gnfderrors.ErrInvalidBucketName, + }, + { + name: "invalid dest object name", + msg: MsgCopyObject{ + Operator: sample.RandAccAddressHex(), + SrcBucketName: testBucketName, + SrcObjectName: testObjectName, + DstBucketName: "dst" + testBucketName, + DstObjectName: "", + DstPrimarySpApproval: &common.Approval{ + ExpiredHeight: 100, + Sig: []byte("xxx"), + }, + }, + err: gnfderrors.ErrInvalidObjectName, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.ValidateBasic() + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMsgSealObject_ValidateBasic(t *testing.T) { + checksums := [][]byte{sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum()} + blsSignDoc := NewSecondarySpSealObjectSignDoc("greenfield_9000-1", 1, math.NewUint(1), GenerateHash(checksums[:])).GetSignBytes() + blsPrivKey, _ := bls.RandKey() + aggSig := blsPrivKey.Sign(blsSignDoc[:]).Marshal() + tests := []struct { + name string + msg MsgSealObject + err error + }{ + { + name: "normal", + msg: MsgSealObject{ + Operator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectName: testObjectName, + SecondarySpBlsAggSignatures: aggSig, + }, + }, + { + name: "invalid address", + msg: MsgSealObject{ + Operator: "invalid address", + BucketName: testBucketName, + ObjectName: testObjectName, + SecondarySpBlsAggSignatures: aggSig, + }, + err: sdkerrors.ErrInvalidAddress, + }, + { + name: "invalid bucket name", + msg: MsgSealObject{ + Operator: sample.RandAccAddressHex(), + BucketName: "1.1.1.1", + ObjectName: testObjectName, + SecondarySpBlsAggSignatures: aggSig, + }, + err: gnfderrors.ErrInvalidBucketName, + }, + { + name: "invalid object name", + msg: MsgSealObject{ + Operator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectName: "", + SecondarySpBlsAggSignatures: aggSig, + }, + err: gnfderrors.ErrInvalidObjectName, + }, + { + name: "invalid signature", + msg: MsgSealObject{ + Operator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectName: testObjectName, + SecondarySpBlsAggSignatures: []byte("invalid signature"), + }, + err: gnfderrors.ErrInvalidBlsSignature, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.ValidateBasic() + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMsgRejectSealObject_ValidateBasic(t *testing.T) { + tests := []struct { + name string + msg MsgRejectSealObject + err error + }{ + { + name: "normal", + msg: MsgRejectSealObject{ + Operator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectName: testObjectName, + }, + }, + { + name: "invalid address", + msg: MsgRejectSealObject{ + Operator: "invalid address", + BucketName: "1.1.1.1", + ObjectName: testObjectName, + }, + err: sdkerrors.ErrInvalidAddress, + }, + { + name: "invalid bucket name", + msg: MsgRejectSealObject{ + Operator: sample.RandAccAddressHex(), + BucketName: "1.1.1.1", + ObjectName: testObjectName, + }, + err: gnfderrors.ErrInvalidBucketName, + }, + { + name: "invalid object name", + msg: MsgRejectSealObject{ + Operator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectName: "", + }, + err: gnfderrors.ErrInvalidObjectName, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.ValidateBasic() + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMsgUpdateObjectInfo_ValidateBasic(t *testing.T) { + tests := []struct { + name string + msg MsgUpdateObjectInfo + err error + }{ + { + name: "normal", + msg: MsgUpdateObjectInfo{ + Operator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectName: testObjectName, + Visibility: VISIBILITY_TYPE_INHERIT, + }, + }, + { + name: "invalid address", + msg: MsgUpdateObjectInfo{ + Operator: "invalid address", + BucketName: testBucketName, + ObjectName: testObjectName, + Visibility: VISIBILITY_TYPE_INHERIT, + }, + err: sdkerrors.ErrInvalidAddress, + }, + { + name: "invalid bucket name", + msg: MsgUpdateObjectInfo{ + Operator: sample.RandAccAddressHex(), + BucketName: "1.1.1.1", + ObjectName: testObjectName, + Visibility: VISIBILITY_TYPE_INHERIT, + }, + err: gnfderrors.ErrInvalidBucketName, + }, + { + name: "invalid bucket name", + msg: MsgUpdateObjectInfo{ + Operator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectName: "", + Visibility: VISIBILITY_TYPE_INHERIT, + }, + err: gnfderrors.ErrInvalidObjectName, + }, + { + name: "invalid visibility", + msg: MsgUpdateObjectInfo{ + Operator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectName: testObjectName, + Visibility: VISIBILITY_TYPE_UNSPECIFIED, + }, + err: ErrInvalidVisibility, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.ValidateBasic() + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMsgMirrorObject_ValidateBasic(t *testing.T) { + tests := []struct { + name string + msg MsgMirrorObject + err error + }{ + { + name: "normal", + msg: MsgMirrorObject{ + Operator: sample.RandAccAddressHex(), + Id: math.NewUint(1), + }, + }, + { + name: "invalid address", + msg: MsgMirrorObject{ + Operator: "wrong address", + Id: math.NewUint(1), + }, + err: sdkerrors.ErrInvalidAddress, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.ValidateBasic() + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMsgDiscontinueObject_ValidateBasic(t *testing.T) { + invalidObjectIds := [MaxDiscontinueObjects + 1]Uint{} + tests := []struct { + name string + msg MsgDiscontinueObject + err error + }{ + { + name: "normal", + msg: MsgDiscontinueObject{ + Operator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectIds: []Uint{math.NewUint(1)}, + Reason: "valid reason", + }, + }, + { + name: "invalid address", + msg: MsgDiscontinueObject{ + Operator: "invalid address", + BucketName: testBucketName, + ObjectIds: []Uint{math.NewUint(1)}, + Reason: "valid reason", + }, + err: sdkerrors.ErrInvalidAddress, + }, + { + name: "invalid bucket name", + msg: MsgDiscontinueObject{ + Operator: sample.RandAccAddressHex(), + BucketName: "1.11.1.1", + ObjectIds: []Uint{math.NewUint(1)}, + Reason: "valid reason", + }, + err: gnfderrors.ErrInvalidBucketName, + }, + { + name: "invalid object ids", + msg: MsgDiscontinueObject{ + Operator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectIds: nil, + Reason: "valid reason", + }, + err: ErrInvalidObjectIds, + }, + { + name: "invalid object ids", + msg: MsgDiscontinueObject{ + Operator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectIds: invalidObjectIds[:], + Reason: "valid reason", + }, + err: ErrInvalidObjectIds, + }, + { + name: "invalid reason", + msg: MsgDiscontinueObject{ + Operator: sample.RandAccAddressHex(), + BucketName: testBucketName, + ObjectIds: []Uint{math.NewUint(1)}, + Reason: strings.Repeat("s", MaxDiscontinueReasonLen+1), + }, + err: ErrInvalidReason, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.ValidateBasic() + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} diff --git a/x/storage/types/message_test.go b/x/storage/types/message_test.go index bb331524e..f3343a3bf 100644 --- a/x/storage/types/message_test.go +++ b/x/storage/types/message_test.go @@ -7,7 +7,6 @@ import ( "cosmossdk.io/math" sdk "github.com/cosmos/cosmos-sdk/types" sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" - "github.com/prysmaticlabs/prysm/crypto/bls" "github.com/stretchr/testify/require" "github.com/bnb-chain/greenfield/testutil/sample" @@ -33,54 +32,54 @@ func TestMsgCreateBucket_ValidateBasic(t *testing.T) { { name: "normal", msg: MsgCreateBucket{ - Creator: sample.AccAddress(), + Creator: sample.RandAccAddressHex(), BucketName: testBucketName, Visibility: VISIBILITY_TYPE_PUBLIC_READ, - PaymentAddress: sample.AccAddress(), - PrimarySpAddress: sample.AccAddress(), + PaymentAddress: sample.RandAccAddressHex(), + PrimarySpAddress: sample.RandAccAddressHex(), PrimarySpApproval: &common.Approval{}, }, }, { name: "invalid bucket name", msg: MsgCreateBucket{ - Creator: sample.AccAddress(), + Creator: sample.RandAccAddressHex(), BucketName: "TestBucket", Visibility: VISIBILITY_TYPE_PUBLIC_READ, - PaymentAddress: sample.AccAddress(), - PrimarySpAddress: sample.AccAddress(), + PaymentAddress: sample.RandAccAddressHex(), + PrimarySpAddress: sample.RandAccAddressHex(), PrimarySpApproval: &common.Approval{}, }, err: gnfderrors.ErrInvalidBucketName, }, { name: "invalid bucket name", msg: MsgCreateBucket{ - Creator: sample.AccAddress(), + Creator: sample.RandAccAddressHex(), BucketName: "Test-Bucket", Visibility: VISIBILITY_TYPE_PUBLIC_READ, - PaymentAddress: sample.AccAddress(), - PrimarySpAddress: sample.AccAddress(), + PaymentAddress: sample.RandAccAddressHex(), + PrimarySpAddress: sample.RandAccAddressHex(), PrimarySpApproval: &common.Approval{}, }, err: gnfderrors.ErrInvalidBucketName, }, { name: "invalid bucket name", msg: MsgCreateBucket{ - Creator: sample.AccAddress(), + Creator: sample.RandAccAddressHex(), BucketName: "ss", Visibility: VISIBILITY_TYPE_PUBLIC_READ, - PaymentAddress: sample.AccAddress(), - PrimarySpAddress: sample.AccAddress(), + PaymentAddress: sample.RandAccAddressHex(), + PrimarySpAddress: sample.RandAccAddressHex(), PrimarySpApproval: &common.Approval{}, }, err: gnfderrors.ErrInvalidBucketName, }, { name: "invalid bucket name", msg: MsgCreateBucket{ - Creator: sample.AccAddress(), + Creator: sample.RandAccAddressHex(), BucketName: string(testInvalidBucketNameWithLongLength[:]), Visibility: VISIBILITY_TYPE_PUBLIC_READ, - PaymentAddress: sample.AccAddress(), - PrimarySpAddress: sample.AccAddress(), + PaymentAddress: sample.RandAccAddressHex(), + PrimarySpAddress: sample.RandAccAddressHex(), PrimarySpApproval: &common.Approval{}, }, err: gnfderrors.ErrInvalidBucketName, @@ -107,13 +106,13 @@ func TestMsgDeleteBucket_ValidateBasic(t *testing.T) { { name: "normal", msg: MsgDeleteBucket{ - Operator: sample.AccAddress(), + Operator: sample.RandAccAddressHex(), BucketName: testBucketName, }, }, { name: "invalid bucket name", msg: MsgDeleteBucket{ - Operator: sample.AccAddress(), + Operator: sample.RandAccAddressHex(), BucketName: "testBucket", }, err: gnfderrors.ErrInvalidBucketName, @@ -140,9 +139,9 @@ func TestMsgUpdateBucketInfo_ValidateBasic(t *testing.T) { { name: "basic", msg: MsgUpdateBucketInfo{ - Operator: sample.AccAddress(), + Operator: sample.RandAccAddressHex(), BucketName: testBucketName, - PaymentAddress: sample.AccAddress(), + PaymentAddress: sample.RandAccAddressHex(), ChargedReadQuota: &common.UInt64Value{Value: 10000}, }, }, @@ -159,261 +158,6 @@ func TestMsgUpdateBucketInfo_ValidateBasic(t *testing.T) { } } -func TestMsgCreateObject_ValidateBasic(t *testing.T) { - tests := []struct { - name string - msg MsgCreateObject - err error - }{ - { - name: "normal", - msg: MsgCreateObject{ - Creator: sample.AccAddress(), - BucketName: testBucketName, - ObjectName: testObjectName, - PayloadSize: 1024, - Visibility: VISIBILITY_TYPE_PRIVATE, - ContentType: "content-type", - PrimarySpApproval: &common.Approval{}, - ExpectChecksums: [][]byte{sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum()}, - }, - }, { - name: "invalid object name", - msg: MsgCreateObject{ - Creator: sample.AccAddress(), - BucketName: testBucketName, - ObjectName: "", - PayloadSize: 1024, - Visibility: VISIBILITY_TYPE_PRIVATE, - ContentType: "content-type", - PrimarySpApproval: &common.Approval{}, - ExpectChecksums: [][]byte{sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum()}, - }, - err: gnfderrors.ErrInvalidObjectName, - }, { - name: "invalid object name", - msg: MsgCreateObject{ - Creator: sample.AccAddress(), - BucketName: testBucketName, - ObjectName: "../object", - PayloadSize: 1024, - Visibility: VISIBILITY_TYPE_PRIVATE, - ContentType: "content-type", - PrimarySpApproval: &common.Approval{}, - ExpectChecksums: [][]byte{sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum()}, - }, - err: gnfderrors.ErrInvalidObjectName, - }, { - name: "invalid object name", - msg: MsgCreateObject{ - Creator: sample.AccAddress(), - BucketName: testBucketName, - ObjectName: "//object", - PayloadSize: 1024, - Visibility: VISIBILITY_TYPE_PRIVATE, - ContentType: "content-type", - PrimarySpApproval: &common.Approval{}, - ExpectChecksums: [][]byte{sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum()}, - }, - err: gnfderrors.ErrInvalidObjectName, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.msg.ValidateBasic() - if tt.err != nil { - require.ErrorIs(t, err, tt.err) - return - } - require.NoError(t, err) - }) - } -} - -func TestMsgCancelCreateObject_ValidateBasic(t *testing.T) { - tests := []struct { - name string - msg MsgCancelCreateObject - err error - }{ - { - name: "basic", - msg: MsgCancelCreateObject{ - Operator: sample.AccAddress(), - BucketName: testBucketName, - ObjectName: testObjectName, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.msg.ValidateBasic() - if tt.err != nil { - require.ErrorIs(t, err, tt.err) - return - } - require.NoError(t, err) - }) - } -} - -func TestMsgDeleteObject_ValidateBasic(t *testing.T) { - tests := []struct { - name string - msg MsgDeleteObject - err error - }{ - { - name: "normal", - msg: MsgDeleteObject{ - Operator: sample.AccAddress(), - BucketName: testBucketName, - ObjectName: testObjectName, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.msg.ValidateBasic() - if tt.err != nil { - require.ErrorIs(t, err, tt.err) - return - } - require.NoError(t, err) - }) - } -} - -func TestMsgCopyObject_ValidateBasic(t *testing.T) { - tests := []struct { - name string - msg MsgCopyObject - err error - }{ - { - name: "valid address", - msg: MsgCopyObject{ - Operator: sample.AccAddress(), - SrcBucketName: testBucketName, - SrcObjectName: testObjectName, - DstBucketName: "dst" + testBucketName, - DstObjectName: "dst" + testObjectName, - DstPrimarySpApproval: &common.Approval{ - ExpiredHeight: 100, - Sig: []byte("xxx"), - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.msg.ValidateBasic() - if tt.err != nil { - require.ErrorIs(t, err, tt.err) - return - } - require.NoError(t, err) - }) - } -} - -func TestMsgSealObject_ValidateBasic(t *testing.T) { - checksums := [][]byte{sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum(), sample.Checksum()} - blsSignDoc := NewSecondarySpSealObjectSignDoc("greenfield_9000-1", 1, math.NewUint(1), GenerateHash(checksums[:])).GetSignBytes() - blsPrivKey, _ := bls.RandKey() - aggSig := blsPrivKey.Sign(blsSignDoc[:]).Marshal() - tests := []struct { - name string - msg MsgSealObject - err error - }{ - { - name: "normal", - msg: MsgSealObject{ - Operator: sample.AccAddress(), - BucketName: testBucketName, - ObjectName: testObjectName, - SecondarySpBlsAggSignatures: aggSig, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.msg.ValidateBasic() - if tt.err != nil { - require.ErrorIs(t, err, tt.err) - return - } - require.NoError(t, err) - }) - } -} - -func TestMsgRejectSealObject_ValidateBasic(t *testing.T) { - tests := []struct { - name string - msg MsgRejectSealObject - err error - }{ - { - name: "normal", - msg: MsgRejectSealObject{ - Operator: sample.AccAddress(), - BucketName: testBucketName, - ObjectName: testObjectName, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.msg.ValidateBasic() - if tt.err != nil { - require.ErrorIs(t, err, tt.err) - return - } - require.NoError(t, err) - }) - } -} - -func TestMsgUpdateObjectInfo_ValidateBasic(t *testing.T) { - tests := []struct { - name string - msg MsgUpdateObjectInfo - err error - }{ - { - name: "normal", - msg: MsgUpdateObjectInfo{ - Operator: sample.AccAddress(), - BucketName: testBucketName, - ObjectName: testObjectName, - Visibility: VISIBILITY_TYPE_INHERIT, - }, - }, - { - name: "abnormal", - msg: MsgUpdateObjectInfo{ - Operator: sample.AccAddress(), - BucketName: testBucketName, - ObjectName: testObjectName, - Visibility: VISIBILITY_TYPE_UNSPECIFIED, - }, - err: ErrInvalidVisibility, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.msg.ValidateBasic() - if tt.err != nil { - require.ErrorIs(t, err, tt.err) - return - } - require.NoError(t, err) - }) - } -} - func TestMsgCreateGroup_ValidateBasic(t *testing.T) { tests := []struct { name string @@ -423,9 +167,9 @@ func TestMsgCreateGroup_ValidateBasic(t *testing.T) { { name: "normal", msg: MsgCreateGroup{ - Creator: sample.AccAddress(), + Creator: sample.RandAccAddressHex(), GroupName: testGroupName, - Members: []string{sample.AccAddress(), sample.AccAddress()}, + Members: []string{sample.RandAccAddressHex(), sample.RandAccAddressHex()}, }, }, } @@ -450,7 +194,7 @@ func TestMsgDeleteGroup_ValidateBasic(t *testing.T) { { name: "normal", msg: MsgDeleteGroup{ - Operator: sample.AccAddress(), + Operator: sample.RandAccAddressHex(), GroupName: testGroupName, }, }, @@ -476,8 +220,8 @@ func TestMsgLeaveGroup_ValidateBasic(t *testing.T) { { name: "normal", msg: MsgLeaveGroup{ - Member: sample.AccAddress(), - GroupOwner: sample.AccAddress(), + Member: sample.RandAccAddressHex(), + GroupOwner: sample.RandAccAddressHex(), GroupName: testGroupName, }, }, @@ -503,11 +247,11 @@ func TestMsgUpdateGroupMember_ValidateBasic(t *testing.T) { { name: "normal", msg: MsgUpdateGroupMember{ - Operator: sample.AccAddress(), - GroupOwner: sample.AccAddress(), + Operator: sample.RandAccAddressHex(), + GroupOwner: sample.RandAccAddressHex(), GroupName: testGroupName, - MembersToAdd: []string{sample.AccAddress(), sample.AccAddress()}, - MembersToDelete: []string{sample.AccAddress(), sample.AccAddress()}, + MembersToAdd: []string{sample.RandAccAddressHex(), sample.RandAccAddressHex()}, + MembersToDelete: []string{sample.RandAccAddressHex(), sample.RandAccAddressHex()}, }, }, } @@ -533,8 +277,8 @@ func TestMsgUpdateGroupExtra_ValidateBasic(t *testing.T) { { name: "normal", msg: MsgUpdateGroupExtra{ - Operator: sample.AccAddress(), - GroupOwner: sample.AccAddress(), + Operator: sample.RandAccAddressHex(), + GroupOwner: sample.RandAccAddressHex(), GroupName: testGroupName, Extra: "testExtra", }, @@ -542,8 +286,8 @@ func TestMsgUpdateGroupExtra_ValidateBasic(t *testing.T) { { name: "extra field is too long", msg: MsgUpdateGroupExtra{ - Operator: sample.AccAddress(), - GroupOwner: sample.AccAddress(), + Operator: sample.RandAccAddressHex(), + GroupOwner: sample.RandAccAddressHex(), GroupName: testGroupName, Extra: strings.Repeat("abcdefg", 80), }, @@ -571,9 +315,9 @@ func TestMsgPutPolicy_ValidateBasic(t *testing.T) { { name: "normal", msg: MsgPutPolicy{ - Operator: sample.AccAddress(), + Operator: sample.RandAccAddressHex(), Resource: types2.NewBucketGRN(testBucketName).String(), - Principal: types.NewPrincipalWithAccount(sdk.MustAccAddressFromHex(sample.AccAddress())), + Principal: types.NewPrincipalWithAccount(sdk.MustAccAddressFromHex(sample.RandAccAddressHex())), Statements: []*types.Statement{{Effect: types.EFFECT_ALLOW, Actions: []types.ActionType{types.ACTION_DELETE_BUCKET}}}, }, @@ -600,9 +344,9 @@ func TestMsgDeletePolicy_ValidateBasic(t *testing.T) { { name: "valid address", msg: MsgDeletePolicy{ - Operator: sample.AccAddress(), + Operator: sample.RandAccAddressHex(), Resource: types2.NewBucketGRN(testBucketName).String(), - Principal: types.NewPrincipalWithAccount(sdk.MustAccAddressFromHex(sample.AccAddress())), + Principal: types.NewPrincipalWithAccount(sdk.MustAccAddressFromHex(sample.RandAccAddressHex())), }, }, } @@ -627,7 +371,7 @@ func TestMsgMirrorBucket_ValidateBasic(t *testing.T) { { name: "normal", msg: MsgMirrorBucket{ - Operator: sample.AccAddress(), + Operator: sample.RandAccAddressHex(), Id: math.NewUint(1), }, }, { @@ -651,40 +395,6 @@ func TestMsgMirrorBucket_ValidateBasic(t *testing.T) { } } -func TestMsgMirrorObject_ValidateBasic(t *testing.T) { - tests := []struct { - name string - msg MsgMirrorObject - err error - }{ - { - name: "normal", - msg: MsgMirrorObject{ - Operator: sample.AccAddress(), - Id: math.NewUint(1), - }, - }, - { - name: "invalid address", - msg: MsgMirrorObject{ - Operator: "wrong address", - Id: math.NewUint(1), - }, - err: sdkerrors.ErrInvalidAddress, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := tt.msg.ValidateBasic() - if tt.err != nil { - require.ErrorIs(t, err, tt.err) - return - } - require.NoError(t, err) - }) - } -} - func TestMsgMirrorGroup_ValidateBasic(t *testing.T) { tests := []struct { name string @@ -694,7 +404,7 @@ func TestMsgMirrorGroup_ValidateBasic(t *testing.T) { { name: "normal", msg: MsgMirrorGroup{ - Operator: sample.AccAddress(), + Operator: sample.RandAccAddressHex(), Id: math.NewUint(1), }, }, diff --git a/x/storage/keeper/options.go b/x/storage/types/options.go similarity index 59% rename from x/storage/keeper/options.go rename to x/storage/types/options.go index 5cf66a102..dc36988e8 100644 --- a/x/storage/keeper/options.go +++ b/x/storage/types/options.go @@ -1,13 +1,12 @@ -package keeper +package types import ( "github.com/bnb-chain/greenfield/types/common" - "github.com/bnb-chain/greenfield/x/storage/types" ) type CreateBucketOptions struct { - Visibility types.VisibilityType - SourceType types.SourceType + Visibility VisibilityType + SourceType SourceType ChargedReadQuota uint64 PaymentAddress string PrimarySpApproval *common.Approval @@ -15,51 +14,55 @@ type CreateBucketOptions struct { } type DeleteBucketOptions struct { - SourceType types.SourceType + SourceType SourceType } type UpdateBucketOptions struct { - Visibility types.VisibilityType - SourceType types.SourceType + Visibility VisibilityType + SourceType SourceType PaymentAddress string ChargedReadQuota *uint64 } type CreateObjectOptions struct { - Visibility types.VisibilityType + Visibility VisibilityType ContentType string - SourceType types.SourceType - RedundancyType types.RedundancyType + SourceType SourceType + RedundancyType RedundancyType Checksums [][]byte PrimarySpApproval *common.Approval ApprovalMsgBytes []byte } type CancelCreateObjectOptions struct { - SourceType types.SourceType + SourceType SourceType } type DeleteObjectOptions struct { - SourceType types.SourceType + SourceType SourceType } type CopyObjectOptions struct { - SourceType types.SourceType - Visibility types.VisibilityType + SourceType SourceType + Visibility VisibilityType PrimarySpApproval *common.Approval ApprovalMsgBytes []byte } type CreateGroupOptions struct { Members []string - SourceType types.SourceType + SourceType SourceType Extra string } type LeaveGroupOptions struct { - SourceType types.SourceType + SourceType SourceType } type UpdateGroupMemberOptions struct { - SourceType types.SourceType + SourceType SourceType MembersToAdd []string MembersToDelete []string } + +type DeleteGroupOptions struct { + SourceType SourceType +} diff --git a/x/virtualgroup/keeper/grpc_query.go b/x/virtualgroup/keeper/grpc_query.go index b31405cfb..2fe684529 100644 --- a/x/virtualgroup/keeper/grpc_query.go +++ b/x/virtualgroup/keeper/grpc_query.go @@ -4,12 +4,13 @@ import ( "context" "math" - "github.com/bnb-chain/greenfield/x/virtualgroup/types" "github.com/cosmos/cosmos-sdk/store/prefix" sdk "github.com/cosmos/cosmos-sdk/types" "github.com/cosmos/cosmos-sdk/types/query" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + + "github.com/bnb-chain/greenfield/x/virtualgroup/types" ) func (k Keeper) Params(c context.Context, req *types.QueryParamsRequest) (*types.QueryParamsResponse, error) { @@ -28,10 +29,6 @@ func (k Keeper) GlobalVirtualGroup(goCtx context.Context, req *types.QueryGlobal ctx := sdk.UnwrapSDKContext(goCtx) - if req == nil { - return nil, status.Error(codes.InvalidArgument, "invalid request") - } - gvg, found := k.GetGVG(ctx, req.GlobalVirtualGroupId) if !found { return nil, types.ErrGVGNotExist @@ -70,10 +67,6 @@ func (k Keeper) GlobalVirtualGroupFamily(goCtx context.Context, req *types.Query ctx := sdk.UnwrapSDKContext(goCtx) - if req == nil { - return nil, status.Error(codes.InvalidArgument, "invalid request") - } - gvgFamily, found := k.GetGVGFamily(ctx, req.FamilyId) if !found { return nil, types.ErrGVGFamilyNotExist diff --git a/x/virtualgroup/module_simulation.go b/x/virtualgroup/module_simulation.go index 44168c0aa..1862eda20 100644 --- a/x/virtualgroup/module_simulation.go +++ b/x/virtualgroup/module_simulation.go @@ -16,7 +16,7 @@ import ( // avoid unused import issue var ( - _ = sample.AccAddress + _ = sample.RandAccAddressHex _ = virtualgroupsimulation.FindAccount _ = simulation.MsgEntryKind _ = baseapp.Paramspace diff --git a/x/virtualgroup/types/message.go b/x/virtualgroup/types/message.go index dac1e3184..8bc75b594 100644 --- a/x/virtualgroup/types/message.go +++ b/x/virtualgroup/types/message.go @@ -64,7 +64,7 @@ func (msg *MsgCreateGlobalVirtualGroup) GetSigners() []sdk.AccAddress { func (msg *MsgCreateGlobalVirtualGroup) ValidateBasic() error { _, err := sdk.AccAddressFromHexUnsafe(msg.StorageProvider) if err != nil { - return err + return sdkerrors.ErrInvalidAddress.Wrapf("invalid storage provider address (%s)", err) } if !msg.Deposit.IsValid() || !msg.Deposit.Amount.IsPositive() { @@ -107,7 +107,7 @@ func (msg *MsgDeleteGlobalVirtualGroup) GetSigners() []sdk.AccAddress { func (msg *MsgDeleteGlobalVirtualGroup) ValidateBasic() error { _, err := sdk.AccAddressFromHexUnsafe(msg.StorageProvider) if err != nil { - return err + return sdkerrors.ErrInvalidAddress.Wrapf("invalid storage provider address (%s)", err) } return nil @@ -147,7 +147,7 @@ func (msg *MsgDeposit) GetSigners() []sdk.AccAddress { func (msg *MsgDeposit) ValidateBasic() error { _, err := sdk.AccAddressFromHexUnsafe(msg.StorageProvider) if err != nil { - return err + return sdkerrors.ErrInvalidAddress.Wrapf("invalid storage provider address (%s)", err) } if !msg.Deposit.IsValid() || !msg.Deposit.Amount.IsPositive() { @@ -188,7 +188,7 @@ func (msg *MsgWithdraw) GetSigners() []sdk.AccAddress { func (msg *MsgWithdraw) ValidateBasic() error { _, err := sdk.AccAddressFromHexUnsafe(msg.StorageProvider) if err != nil { - return err + return sdkerrors.ErrInvalidAddress.Wrapf("invalid storage provider address (%s)", err) } if !msg.Withdraw.IsValid() || !msg.Withdraw.Amount.IsPositive() { @@ -235,7 +235,7 @@ func (msg *MsgSwapOut) GetSigners() []sdk.AccAddress { func (msg *MsgSwapOut) ValidateBasic() error { _, err := sdk.AccAddressFromHexUnsafe(msg.StorageProvider) if err != nil { - return err + return sdkerrors.ErrInvalidAddress.Wrapf("invalid storage provider address (%s)", err) } if msg.GlobalVirtualGroupFamilyId == NoSpecifiedFamilyId { @@ -249,7 +249,11 @@ func (msg *MsgSwapOut) ValidateBasic() error { } if msg.SuccessorSpId == 0 { - return gnfderrors.ErrInvalidMessage.Wrap("The successor sp id is not specify.") + return gnfderrors.ErrInvalidMessage.Wrap("The successor sp id is not specified.") + } + + if msg.SuccessorSpApproval == nil { + return gnfderrors.ErrInvalidMessage.Wrap("The successor sp approval is not specified.") } return nil @@ -310,7 +314,7 @@ func (msg *MsgSettle) GetSigners() []sdk.AccAddress { func (msg *MsgSettle) ValidateBasic() error { _, err := sdk.AccAddressFromHexUnsafe(msg.StorageProvider) if err != nil { - return err + return sdkerrors.ErrInvalidAddress.Wrapf("invalid storage provider address (%s)", err) } if msg.GlobalVirtualGroupFamilyId == NoSpecifiedFamilyId { diff --git a/x/virtualgroup/types/message_cancel_swap_out_test.go b/x/virtualgroup/types/message_cancel_swap_out_test.go index 5aa4d79f6..ee657ba38 100644 --- a/x/virtualgroup/types/message_cancel_swap_out_test.go +++ b/x/virtualgroup/types/message_cancel_swap_out_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" "github.com/bnb-chain/greenfield/testutil/sample" + gnfderrors "github.com/bnb-chain/greenfield/types/errors" ) func TestMsgCancelSwapOut_ValidateBasic(t *testing.T) { @@ -15,6 +16,14 @@ func TestMsgCancelSwapOut_ValidateBasic(t *testing.T) { msg MsgCancelSwapOut err error }{ + { + name: "valid address", + msg: *NewMsgCancelSwapOut( + sample.RandAccAddress(), + 1, + []uint32{}, + ), + }, { name: "invalid address", msg: MsgCancelSwapOut{ @@ -22,12 +31,24 @@ func TestMsgCancelSwapOut_ValidateBasic(t *testing.T) { GlobalVirtualGroupFamilyId: 1, }, err: sdkerrors.ErrInvalidAddress, - }, { - name: "valid address", + }, + { + name: "invalid gvg groups", msg: MsgCancelSwapOut{ - StorageProvider: sample.AccAddress(), + StorageProvider: sample.RandAccAddressHex(), GlobalVirtualGroupFamilyId: 1, + GlobalVirtualGroupIds: []uint32{1, 2}, + }, + err: gnfderrors.ErrInvalidMessage, + }, + { + name: "invalid gvg groups", + msg: MsgCancelSwapOut{ + StorageProvider: sample.RandAccAddressHex(), + GlobalVirtualGroupFamilyId: 0, + GlobalVirtualGroupIds: []uint32{}, }, + err: gnfderrors.ErrInvalidMessage, }, } for _, tt := range tests { diff --git a/x/virtualgroup/types/message_complete_storage_provider_exit_test.go b/x/virtualgroup/types/message_complete_storage_provider_exit_test.go index 97b94fedc..b9948939f 100644 --- a/x/virtualgroup/types/message_complete_storage_provider_exit_test.go +++ b/x/virtualgroup/types/message_complete_storage_provider_exit_test.go @@ -23,9 +23,7 @@ func TestMsgCompleteStorageProviderExit_ValidateBasic(t *testing.T) { err: sdkerrors.ErrInvalidAddress, }, { name: "valid address", - msg: MsgCompleteStorageProviderExit{ - StorageProvider: sample.AccAddress(), - }, + msg: *NewMsgCompleteStorageProviderExit(sample.RandAccAddress()), }, } for _, tt := range tests { diff --git a/x/virtualgroup/types/message_complete_swap_out_test.go b/x/virtualgroup/types/message_complete_swap_out_test.go index 7d17fcd8f..96a396836 100644 --- a/x/virtualgroup/types/message_complete_swap_out_test.go +++ b/x/virtualgroup/types/message_complete_swap_out_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" "github.com/bnb-chain/greenfield/testutil/sample" + gnfderrors "github.com/bnb-chain/greenfield/types/errors" ) func TestMsgCompleteSwapOut_ValidateBasic(t *testing.T) { @@ -15,6 +16,14 @@ func TestMsgCompleteSwapOut_ValidateBasic(t *testing.T) { msg MsgCompleteSwapOut err error }{ + { + name: "valid address", + msg: *NewMsgCompleteSwapOut( + sample.RandAccAddress(), + 1, + []uint32{}, + ), + }, { name: "invalid address", msg: MsgCompleteSwapOut{ @@ -22,12 +31,24 @@ func TestMsgCompleteSwapOut_ValidateBasic(t *testing.T) { GlobalVirtualGroupFamilyId: 1, }, err: sdkerrors.ErrInvalidAddress, - }, { - name: "valid address", + }, + { + name: "invalid gvg groups", msg: MsgCompleteSwapOut{ - StorageProvider: sample.AccAddress(), + StorageProvider: sample.RandAccAddressHex(), GlobalVirtualGroupFamilyId: 1, + GlobalVirtualGroupIds: []uint32{1, 2, 3}, + }, + err: gnfderrors.ErrInvalidMessage, + }, + { + name: "invalid gvg groups", + msg: MsgCompleteSwapOut{ + StorageProvider: sample.RandAccAddressHex(), + GlobalVirtualGroupFamilyId: 0, + GlobalVirtualGroupIds: []uint32{}, }, + err: gnfderrors.ErrInvalidMessage, }, } for _, tt := range tests { diff --git a/x/virtualgroup/types/message_storage_provider_exit_test.go b/x/virtualgroup/types/message_storage_provider_exit_test.go index 49e2e2b57..e849077ee 100644 --- a/x/virtualgroup/types/message_storage_provider_exit_test.go +++ b/x/virtualgroup/types/message_storage_provider_exit_test.go @@ -23,9 +23,7 @@ func TestMsgStorageProviderExit_ValidateBasic(t *testing.T) { err: sdkerrors.ErrInvalidAddress, }, { name: "valid address", - msg: MsgStorageProviderExit{ - StorageProvider: sample.AccAddress(), - }, + msg: *NewMsgStorageProviderExit(sample.RandAccAddress()), }, } for _, tt := range tests { diff --git a/x/virtualgroup/types/message_test.go b/x/virtualgroup/types/message_test.go new file mode 100644 index 000000000..8e77991dc --- /dev/null +++ b/x/virtualgroup/types/message_test.go @@ -0,0 +1,381 @@ +package types + +import ( + "testing" + + "github.com/cosmos/cosmos-sdk/types" + sdkerrors "github.com/cosmos/cosmos-sdk/types/errors" + "github.com/stretchr/testify/require" + + "github.com/bnb-chain/greenfield/testutil/sample" + "github.com/bnb-chain/greenfield/types/common" + gnfderrors "github.com/bnb-chain/greenfield/types/errors" +) + +func TestMsgDeposit_ValidateBasic(t *testing.T) { + tests := []struct { + name string + msg MsgDeposit + err error + }{ + { + name: "invalid address", + msg: MsgDeposit{ + StorageProvider: "invalid_address", + GlobalVirtualGroupId: 1, + Deposit: types.Coin{ + Denom: "denom", + Amount: types.NewInt(1), + }, + }, + err: sdkerrors.ErrInvalidAddress, + }, + { + name: "invalid deposit amount", + msg: MsgDeposit{ + StorageProvider: sample.RandAccAddressHex(), + GlobalVirtualGroupId: 1, + Deposit: types.Coin{ + Denom: "denom", + Amount: types.NewInt(0), + }, + }, + err: sdkerrors.ErrInvalidRequest, + }, + { + name: "valid case", + msg: *NewMsgDeposit( + sample.RandAccAddress(), + 1, + types.Coin{ + Denom: "denom", + Amount: types.NewInt(1), + }, + ), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.ValidateBasic() + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMsgWithdraw_ValidateBasic(t *testing.T) { + tests := []struct { + name string + msg MsgWithdraw + err error + }{ + { + name: "invalid address", + msg: MsgWithdraw{ + StorageProvider: "invalid_address", + GlobalVirtualGroupId: 1, + Withdraw: types.Coin{ + Denom: "denom", + Amount: types.NewInt(1), + }, + }, + err: sdkerrors.ErrInvalidAddress, + }, + { + name: "invalid deposit amount", + msg: MsgWithdraw{ + StorageProvider: sample.RandAccAddressHex(), + GlobalVirtualGroupId: 1, + Withdraw: types.Coin{ + Denom: "denom", + Amount: types.NewInt(0), + }, + }, + err: sdkerrors.ErrInvalidRequest, + }, + { + name: "valid case", + msg: *NewMsgWithdraw( + sample.RandAccAddress(), + 1, + types.Coin{ + Denom: "denom", + Amount: types.NewInt(1), + }, + ), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.ValidateBasic() + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMsgSwapOut_ValidateBasic(t *testing.T) { + tests := []struct { + name string + msg MsgSwapOut + err error + }{ + { + name: "valid case", + msg: MsgSwapOut{ + StorageProvider: sample.RandAccAddressHex(), + GlobalVirtualGroupFamilyId: 0, + GlobalVirtualGroupIds: []uint32{1, 2, 3}, + SuccessorSpId: 100, + SuccessorSpApproval: &common.Approval{ + ExpiredHeight: 100, + GlobalVirtualGroupFamilyId: 1, + Sig: []byte("sig"), + }, + }, + }, + { + name: "valid case", + msg: MsgSwapOut{ + StorageProvider: sample.RandAccAddressHex(), + GlobalVirtualGroupFamilyId: 1, + GlobalVirtualGroupIds: []uint32{}, + SuccessorSpId: 100, + SuccessorSpApproval: &common.Approval{ + ExpiredHeight: 100, + GlobalVirtualGroupFamilyId: 1, + Sig: []byte("sig"), + }, + }, + }, + { + name: "invalid address", + msg: MsgSwapOut{ + StorageProvider: "invalid address", + GlobalVirtualGroupFamilyId: 1, + GlobalVirtualGroupIds: []uint32{}, + SuccessorSpId: 100, + SuccessorSpApproval: &common.Approval{ + ExpiredHeight: 100, + GlobalVirtualGroupFamilyId: 1, + Sig: []byte("sig"), + }, + }, + err: sdkerrors.ErrInvalidAddress, + }, + { + name: "invalid virtual group family", + msg: MsgSwapOut{ + StorageProvider: sample.RandAccAddressHex(), + GlobalVirtualGroupFamilyId: 1, + GlobalVirtualGroupIds: []uint32{1}, + SuccessorSpId: 100, + SuccessorSpApproval: &common.Approval{ + ExpiredHeight: 100, + GlobalVirtualGroupFamilyId: 1, + Sig: []byte("sig"), + }, + }, + err: gnfderrors.ErrInvalidMessage, + }, + { + name: "invalid virtual group family", + msg: MsgSwapOut{ + StorageProvider: sample.RandAccAddressHex(), + GlobalVirtualGroupFamilyId: 0, + GlobalVirtualGroupIds: []uint32{}, + SuccessorSpId: 100, + SuccessorSpApproval: &common.Approval{ + ExpiredHeight: 100, + GlobalVirtualGroupFamilyId: 1, + Sig: []byte("sig"), + }, + }, + err: gnfderrors.ErrInvalidMessage, + }, + { + name: "invalid successor sp id", + msg: MsgSwapOut{ + StorageProvider: sample.RandAccAddressHex(), + GlobalVirtualGroupFamilyId: 1, + GlobalVirtualGroupIds: []uint32{}, + SuccessorSpId: 0, + SuccessorSpApproval: &common.Approval{ + ExpiredHeight: 100, + GlobalVirtualGroupFamilyId: 1, + Sig: []byte("sig"), + }, + }, + err: gnfderrors.ErrInvalidMessage, + }, + { + name: "invalid successor sp approval", + msg: MsgSwapOut{ + StorageProvider: sample.RandAccAddressHex(), + GlobalVirtualGroupFamilyId: 1, + GlobalVirtualGroupIds: []uint32{}, + SuccessorSpId: 1, + }, + err: gnfderrors.ErrInvalidMessage, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.ValidateBasic() + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMsgCreateGlobalVirtualGroup_ValidateBasic(t *testing.T) { + tests := []struct { + name string + msg MsgCreateGlobalVirtualGroup + err error + }{ + { + name: "valid case", + msg: *NewMsgCreateGlobalVirtualGroup( + sample.RandAccAddress(), + 1, + []uint32{2, 3, 4}, + types.Coin{ + Denom: "denom", + Amount: types.NewInt(1), + }, + ), + }, + { + name: "invalid address", + msg: MsgCreateGlobalVirtualGroup{ + StorageProvider: "invalid_address", + FamilyId: 1, + SecondarySpIds: []uint32{2, 3, 4}, + Deposit: types.Coin{ + Denom: "denom", + Amount: types.NewInt(1), + }, + }, + err: sdkerrors.ErrInvalidAddress, + }, + { + name: "invalid deposit coin", + msg: MsgCreateGlobalVirtualGroup{ + StorageProvider: "invalid_address", + FamilyId: 1, + SecondarySpIds: []uint32{2, 3, 4}, + Deposit: types.Coin{ + Denom: "denom", + Amount: types.NewInt(0), + }, + }, + err: sdkerrors.ErrInvalidAddress, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.ValidateBasic() + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMsgDeleteGlobalVirtualGroup_ValidateBasic(t *testing.T) { + tests := []struct { + name string + msg MsgDeleteGlobalVirtualGroup + err error + }{ + { + name: "valid case", + msg: *NewMsgDeleteGlobalVirtualGroup( + sample.RandAccAddress(), + 1, + ), + }, + { + name: "invalid address", + msg: MsgDeleteGlobalVirtualGroup{ + StorageProvider: "invalid_address", + GlobalVirtualGroupId: 1, + }, + err: sdkerrors.ErrInvalidAddress, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.ValidateBasic() + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMsgSettle_ValidateBasic(t *testing.T) { + tests := []struct { + name string + msg MsgSettle + err error + }{ + { + name: "valid case", + msg: *NewMsgSettle( + sample.RandAccAddress(), + 1, + []uint32{1, 2, 3, 4}, + ), + }, + { + name: "invalid address", + msg: MsgSettle{ + StorageProvider: "invalid_address", + GlobalVirtualGroupFamilyId: 1, + }, + err: sdkerrors.ErrInvalidAddress, + }, + { + name: "invalid gvg ids", + msg: MsgSettle{ + StorageProvider: sample.RandAccAddressHex(), + GlobalVirtualGroupFamilyId: 0, + }, + err: ErrInvalidGVGCount, + }, + { + name: "invalid gvg ids", + msg: MsgSettle{ + StorageProvider: sample.RandAccAddressHex(), + GlobalVirtualGroupFamilyId: 0, + GlobalVirtualGroupIds: []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + }, + err: ErrInvalidGVGCount, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.msg.ValidateBasic() + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} diff --git a/x/virtualgroup/types/params.go b/x/virtualgroup/types/params.go index 4ee62fa37..f9962ffbf 100644 --- a/x/virtualgroup/types/params.go +++ b/x/virtualgroup/types/params.go @@ -120,7 +120,7 @@ func validateMaxStoreSizePerFamily(i interface{}) error { } if v == 0 { - return fmt.Errorf("max buckets per account must be positive: %d", v) + return fmt.Errorf("max store size of family must be positive: %d", v) } return nil diff --git a/x/virtualgroup/types/params_test.go b/x/virtualgroup/types/params_test.go new file mode 100644 index 000000000..da5c08091 --- /dev/null +++ b/x/virtualgroup/types/params_test.go @@ -0,0 +1,154 @@ +package types + +import ( + "testing" + + sdk "github.com/cosmos/cosmos-sdk/types" + "github.com/stretchr/testify/require" +) + +func TestDepositDenom(t *testing.T) { + tests := []struct { + name string + denom interface{} + err string + }{ + + { + name: "valid", + denom: "denom", + }, + { + name: "invalid type", + denom: 1, + err: "invalid parameter type", + }, + { + name: "empty", + denom: " ", + err: "deposit denom cannot be blank", + }, + { + name: "invalid denom", + denom: "%", + err: "invalid denom", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateDepositDenom(tt.denom) + if tt.err != "" { + require.ErrorContains(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestGVGStakingPerBytes(t *testing.T) { + tests := []struct { + name string + ratio interface{} + err string + }{ + + { + name: "valid", + ratio: sdk.NewDec(1), + }, + { + name: "invalid type", + ratio: 1, + err: "invalid parameter type", + }, + { + name: "invalid ratio", + ratio: sdk.NewDec(100), + err: "invalid secondary sp store price ratio", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateGVGStakingPerBytes(tt.ratio) + if tt.err != "" { + require.ErrorContains(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMaxGlobalVirtualGroupNumPerFamily(t *testing.T) { + tests := []struct { + name string + number interface{} + err string + }{ + + { + name: "valid", + number: uint32(1), + }, + { + name: "invalid type", + number: 1, + err: "invalid parameter type", + }, + { + name: "invalid size", + number: uint32(0), + err: "max buckets per account must be positive", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateMaxGlobalVirtualGroupNumPerFamily(tt.number) + if tt.err != "" { + require.ErrorContains(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestMaxStoreSizePerFamily(t *testing.T) { + tests := []struct { + name string + size interface{} + err string + }{ + + { + name: "valid", + size: uint64(1), + }, + { + name: "invalid type", + size: 1, + err: "invalid parameter type", + }, + { + name: "invalid size", + size: uint64(0), + err: "max store size of family must be positive", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateMaxStoreSizePerFamily(tt.size) + if tt.err != "" { + require.ErrorContains(t, err, tt.err) + return + } + require.NoError(t, err) + }) + } +} + +func TestValidateParams(t *testing.T) { + err := DefaultParams().Validate() + require.NoError(t, err) +}