diff --git a/.github/workflows/v2-test.yml b/.github/workflows/v2-test.yml new file mode 100644 index 000000000000..33717512874f --- /dev/null +++ b/.github/workflows/v2-test.yml @@ -0,0 +1,39 @@ +name: v2 core Tests +on: + pull_request: + merge_group: + push: + branches: + - main + +permissions: + contents: read + +concurrency: + group: ci-${{ github.ref }}-v2-tests + cancel-in-progress: true + +jobs: + stf: + runs-on: ubuntu-latest + strategy: + fail-fast: false + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version: "1.22" + check-latest: true + cache: true + cache-dependency-path: go.sum + - uses: technote-space/get-diff-action@v6.1.2 + id: git_diff + with: + PATTERNS: | + server/v2/stf/**/*.go + server/v2/stf/go.mod + server/v2/stf/go.sum + - name: test & coverage report creation + if: env.GIT_DIFF + run: | + cd server/v2/stf && go test -mod=readonly -race -timeout 30m -covermode=atomic -tags='ledger test_ledger_mock' diff --git a/core/app/app.go b/core/app/app.go new file mode 100644 index 000000000000..055587d5d712 --- /dev/null +++ b/core/app/app.go @@ -0,0 +1,66 @@ +package app + +import ( + "time" + + appmodulev2 "cosmossdk.io/core/appmodule/v2" + "cosmossdk.io/core/event" + "cosmossdk.io/core/transaction" +) + +type QueryRequest struct { + Height int64 + Path string + Data []byte +} + +type QueryResponse struct { + Height int64 + Value []byte +} + +type BlockRequest[T any] struct { + Height uint64 + Time time.Time + Hash []byte + ChainId string + AppHash []byte + Txs []T + ConsensusMessages []transaction.Type +} + +type BlockResponse struct { + Apphash []byte + ConsensusMessagesResponse []transaction.Type + ValidatorUpdates []appmodulev2.ValidatorUpdate + PreBlockEvents []event.Event + BeginBlockEvents []event.Event + TxResults []TxResult + EndBlockEvents []event.Event +} + +type RequestInitChain struct { + Time time.Time + ChainId string + Validators []appmodulev2.ValidatorUpdate + AppStateBytes []byte + InitialHeight int64 +} + +type ResponseInitChain struct { + Validators []appmodulev2.ValidatorUpdate + AppHash []byte +} + +type TxResult struct { + Events []event.Event + Resp []transaction.Type + Error error + Code uint32 + Data []byte + Log string + Info string + GasWanted uint64 + GasUsed uint64 + Codespace string +} diff --git a/core/app/codec.go b/core/app/codec.go new file mode 100644 index 000000000000..5673020c7575 --- /dev/null +++ b/core/app/codec.go @@ -0,0 +1,21 @@ +package app + +import ( + "github.com/cosmos/gogoproto/jsonpb" + gogoproto "github.com/cosmos/gogoproto/proto" +) + +// MsgInterfaceProtoName defines the protobuf name of the cosmos Msg interface +const MsgInterfaceProtoName = "cosmos.base.v1beta1.Msg" + +type ProtoCodec interface { + Marshal(v gogoproto.Message) ([]byte, error) + Unmarshal(data []byte, v gogoproto.Message) error + Name() string +} + +type InterfaceRegistry interface { + jsonpb.AnyResolver + ListImplementations(ifaceTypeURL string) []string + ListAllInterfaces() []string +} diff --git a/core/app/identity.go b/core/app/identity.go new file mode 100644 index 000000000000..861135cd5a7b --- /dev/null +++ b/core/app/identity.go @@ -0,0 +1,6 @@ +package app + +var ( + RuntimeIdentity = []byte("runtime") + ConsensusIdentity = []byte("consensus") +) diff --git a/core/context/context.go b/core/context/context.go new file mode 100644 index 000000000000..4c3ab156dd8a --- /dev/null +++ b/core/context/context.go @@ -0,0 +1,16 @@ +package appmodule + +// ExecMode defines the execution mode which can be set on a Context. +type ExecMode uint8 + +// All possible execution modes. +const ( + ExecModeCheck ExecMode = iota + ExecModeReCheck + ExecModeSimulate + ExecModePrepareProposal + ExecModeProcessProposal + ExecModeVoteExtension + ExecModeVerifyVoteExtension + ExecModeFinalize +) diff --git a/core/gas/service.go b/core/gas/service.go index 2fb587d0d961..ba04d84eed55 100644 --- a/core/gas/service.go +++ b/core/gas/service.go @@ -46,8 +46,8 @@ type Service interface { // Meter represents a gas meter for modules consumption type Meter interface { - Consume(amount Gas, descriptor string) - Refund(amount Gas, descriptor string) + Consume(amount Gas, descriptor string) error + Refund(amount Gas, descriptor string) error Remaining() Gas Limit() Gas } diff --git a/core/header/service.go b/core/header/service.go index 2dcebc4f151c..b97089e129b5 100644 --- a/core/header/service.go +++ b/core/header/service.go @@ -2,6 +2,9 @@ package header import ( "context" + "crypto/sha256" + "encoding/binary" + "errors" "time" ) @@ -15,6 +18,74 @@ type Info struct { Height int64 // Height returns the height of the block Hash []byte // Hash returns the hash of the block header Time time.Time // Time returns the time of the block - ChainID string // ChainId returns the chain ID of the block AppHash []byte // AppHash used in the current block header + ChainID string // ChainId returns the chain ID of the block +} + +const hashSize = sha256.Size + +// Bytes encodes the Info struct into a byte slice using little-endian encoding +func (i *Info) Bytes() ([]byte, error) { + buf := make([]byte, 0) + + // Encode Height + heightBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(heightBytes, uint64(i.Height)) + buf = append(buf, heightBytes...) + + // Encode Hash + if len(i.Hash) != hashSize { + return nil, errors.New("invalid hash size") + } + buf = append(buf, i.Hash...) + + // Encode Time + timeBytes := make([]byte, 8) + binary.LittleEndian.PutUint64(timeBytes, uint64(i.Time.Unix())) + buf = append(buf, timeBytes...) + + // Encode AppHash + if len(i.Hash) != hashSize { + return nil, errors.New("invalid hash size") + } + buf = append(buf, i.AppHash...) + + // Encode ChainID + chainIDLen := len(i.ChainID) + buf = append(buf, byte(chainIDLen)) + buf = append(buf, []byte(i.ChainID)...) + + return buf, nil +} + +// FromBytes decodes the byte slice into an Info struct using little-endian encoding +func (i *Info) FromBytes(bytes []byte) error { + // Decode Height + i.Height = int64(binary.LittleEndian.Uint64(bytes[:8])) + bytes = bytes[8:] + + // Decode Hash + i.Hash = make([]byte, hashSize) + copy(i.Hash, bytes[:hashSize]) + bytes = bytes[hashSize:] + + // Decode Time + unixTime := int64(binary.LittleEndian.Uint64(bytes[:8])) + i.Time = time.Unix(unixTime, 0).UTC() + bytes = bytes[8:] + + // Decode AppHash + i.AppHash = make([]byte, hashSize) + copy(i.AppHash, bytes[:hashSize]) + bytes = bytes[hashSize:] + + // Decode ChainID + chainIDLen := int(bytes[0]) + bytes = bytes[1:] + if len(bytes) < chainIDLen { + return errors.New("invalid byte slice length") + } + i.ChainID = string(bytes[:chainIDLen]) + + return nil } diff --git a/core/header/service_test.go b/core/header/service_test.go new file mode 100644 index 000000000000..5aff568628d7 --- /dev/null +++ b/core/header/service_test.go @@ -0,0 +1,56 @@ +package header + +import ( + "crypto/sha256" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestInfo_Bytes(t *testing.T) { + sum := sha256.Sum256([]byte("test-chain")) + info := Info{ + Height: 12345, + Hash: sum[:], + Time: time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC), + AppHash: sum[:], + ChainID: "test-chain", + } + + expectedBytes := []byte{ + 0x39, 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Height (little-endian) + 0x26, 0xb0, 0xb8, 0x3e, 0x72, 0x81, 0xbe, 0x3b, 0x11, 0x76, 0x58, 0xb6, 0xf2, 0x63, 0x6d, 0x3, 0x68, 0xca, 0xd3, 0xd7, 0x4f, 0x22, 0x24, 0x34, 0x28, 0xf5, 0x40, 0x1a, 0x4b, 0x70, 0x89, 0x7e, // Hash + 0x80, 0x0, 0x92, 0x65, 0x0, 0x0, 0x0, 0x0, // Time (little-endian) + 0x26, 0xb0, 0xb8, 0x3e, 0x72, 0x81, 0xbe, 0x3b, 0x11, 0x76, 0x58, 0xb6, 0xf2, 0x63, 0x6d, 0x3, 0x68, 0xca, 0xd3, 0xd7, 0x4f, 0x22, 0x24, 0x34, 0x28, 0xf5, 0x40, 0x1a, 0x4b, 0x70, 0x89, 0x7e, // Apphash + 0x0A, // ChainID length + 0x74, 0x65, 0x73, 0x74, 0x2d, 0x63, 0x68, 0x61, 0x69, 0x6e, // ChainID + } + + bytes, err := info.Bytes() + require.NoError(t, err) + require.Equal(t, expectedBytes, bytes) +} + +func TestInfo_FromBytes(t *testing.T) { + info := Info{} + + // Test case 1: Valid byte slice + bytes := []byte{ + 0x39, 0x30, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // Height (little-endian) + 0x26, 0xb0, 0xb8, 0x3e, 0x72, 0x81, 0xbe, 0x3b, 0x11, 0x76, 0x58, 0xb6, 0xf2, 0x63, 0x6d, 0x3, 0x68, 0xca, 0xd3, 0xd7, 0x4f, 0x22, 0x24, 0x34, 0x28, 0xf5, 0x40, 0x1a, 0x4b, 0x70, 0x89, 0x7e, // Hash + 0x80, 0x0, 0x92, 0x65, 0x0, 0x0, 0x0, 0x0, // Time (little-endian) + 0x26, 0xb0, 0xb8, 0x3e, 0x72, 0x81, 0xbe, 0x3b, 0x11, 0x76, 0x58, 0xb6, 0xf2, 0x63, 0x6d, 0x3, 0x68, 0xca, 0xd3, 0xd7, 0x4f, 0x22, 0x24, 0x34, 0x28, 0xf5, 0x40, 0x1a, 0x4b, 0x70, 0x89, 0x7e, // Apphash + 0x0A, // ChainID length + 0x74, 0x65, 0x73, 0x74, 0x2d, 0x63, 0x68, 0x61, 0x69, 0x6e, // ChainID + } + + err := info.FromBytes(bytes) + require.NoError(t, err) + require.Equal(t, int64(12345), info.Height) + require.Equal(t, []byte{0x26, 0xb0, 0xb8, 0x3e, 0x72, 0x81, 0xbe, 0x3b, 0x11, 0x76, 0x58, 0xb6, 0xf2, 0x63, 0x6d, 0x3, 0x68, 0xca, 0xd3, 0xd7, 0x4f, 0x22, 0x24, 0x34, 0x28, 0xf5, 0x40, 0x1a, 0x4b, 0x70, 0x89, 0x7e}, info.Hash) + require.Equal(t, time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC), info.Time) + require.Equal(t, []byte{0x26, 0xb0, 0xb8, 0x3e, 0x72, 0x81, 0xbe, 0x3b, 0x11, 0x76, 0x58, 0xb6, 0xf2, 0x63, 0x6d, 0x3, 0x68, 0xca, 0xd3, 0xd7, 0x4f, 0x22, 0x24, 0x34, 0x28, 0xf5, 0x40, 0x1a, 0x4b, 0x70, 0x89, 0x7e}, info.AppHash) + require.Equal(t, "test-chain", info.ChainID) + +} diff --git a/core/transaction/transaction.go b/core/transaction/transaction.go index be1e2960ad94..2f6b8ca99130 100644 --- a/core/transaction/transaction.go +++ b/core/transaction/transaction.go @@ -20,12 +20,12 @@ type Tx interface { // Hash returns the unique identifier for the Tx. Hash() [32]byte // TODO evaluate if 32 bytes is the right size & benchmark overhead of hashing instead of using identifier // GetMessages returns the list of state transitions of the Tx. - GetMessages() []Type + GetMessages() ([]Type, error) // GetSenders returns the tx state transition sender. - GetSenders() []Identity // TODO reduce this to a single identity if accepted + GetSenders() ([]Identity, error) // TODO reduce this to a single identity if accepted // GetGasLimit returns the gas limit of the tx. Must return math.MaxUint64 for infinite gas // txs. - GetGasLimit() uint64 + GetGasLimit() (uint64, error) // Bytes returns the encoded version of this tx. Note: this is ideally cached // from the first instance of the decoding of the tx. Bytes() []byte diff --git a/go.work.example b/go.work.example index 428d582dcf5d..32417bd6ef41 100644 --- a/go.work.example +++ b/go.work.example @@ -15,6 +15,7 @@ use ( ./orm ./simapp ./tests + ./server/v2/stf ./store ./store/v2 ./tools/cosmovisor diff --git a/runtime/gas.go b/runtime/gas.go index ec7f5014047c..50121da9e069 100644 --- a/runtime/gas.go +++ b/runtime/gas.go @@ -59,11 +59,15 @@ func (gm SDKGasMeter) Limit() storetypes.Gas { } func (gm SDKGasMeter) ConsumeGas(amount storetypes.Gas, descriptor string) { - gm.gm.Consume(amount, descriptor) + if err := gm.gm.Consume(amount, descriptor); err != nil { + panic(err) + } } func (gm SDKGasMeter) RefundGas(amount storetypes.Gas, descriptor string) { - gm.gm.Refund(amount, descriptor) + if err := gm.gm.Refund(amount, descriptor); err != nil { + panic(err) + } } func (gm SDKGasMeter) IsPastLimit() bool { @@ -83,12 +87,14 @@ type CoreGasmeter struct { gm storetypes.GasMeter } -func (cgm CoreGasmeter) Consume(amount gas.Gas, descriptor string) { +func (cgm CoreGasmeter) Consume(amount gas.Gas, descriptor string) error { cgm.gm.ConsumeGas(amount, descriptor) + return nil } -func (cgm CoreGasmeter) Refund(amount gas.Gas, descriptor string) { +func (cgm CoreGasmeter) Refund(amount gas.Gas, descriptor string) error { cgm.gm.RefundGas(amount, descriptor) + return nil } func (cgm CoreGasmeter) Remaining() gas.Gas { diff --git a/server/v2/stf/README.md b/server/v2/stf/README.md new file mode 100644 index 000000000000..48c40ca75b67 --- /dev/null +++ b/server/v2/stf/README.md @@ -0,0 +1,35 @@ +# State Transition Function (STF) + +STF is a function that takes a state and an action as input and returns the next state. It does not assume the execution model of the application nor consensus. + + +The state transition function receives a read only instance of state. It does not directly write to disk, instead it will return the state changes which has undergone within the application. The state transition function is deterministic, meaning that given the same input, it will always produce the same output. + +## BranchDB + +BranchDB is a cache of all the reads done within a block, simulation or transaction validation. It takes a read-only instance of state and creates its own write instance using a btree. After all state transitions are done, the new change sets are returned to the caller. + +The BranchDB can be replaced and optimized for specific use cases. The implementation is as follows + +```go + type branchdb func(state store.ReaderMap) store.WriterMap +``` + +## GasMeter + +GasMeter is a utility that keeps track of the gas consumed by the state transition function. It is used to limit the amount of computation that can be done within a block. + +The GasMeter can be replaced and optimized for specific use cases. The implementation is as follows: + +```go +type ( + // gasMeter is a function type that takes a gas limit as input and returns a gas.Meter. + // It is used to measure and limit the amount of gas consumed during the execution of a function. + gasMeter func(gasLimit uint64) gas.Meter + + // wrapGasMeter is a function type that wraps a gas meter and a store writer map. + wrapGasMeter func(meter gas.Meter, store store.WriterMap) store.WriterMap +) +``` + +THe wrappGasMeter is used in order to consume gas. Application developers can seamlsessly replace the gas meter with their own implementation in order to customize consumption of gas. diff --git a/server/v2/stf/branch/branch_test.go b/server/v2/stf/branch/branch_test.go new file mode 100644 index 000000000000..e306a2cadfdc --- /dev/null +++ b/server/v2/stf/branch/branch_test.go @@ -0,0 +1,151 @@ +package branch + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/tidwall/btree" + + "cosmossdk.io/core/store" +) + +func TestBranch(t *testing.T) { + set := func(s interface{ Set([]byte, []byte) error }, key, value string) { + require.NoError(t, s.Set([]byte(key), []byte(value))) + } + get := func(s interface{ Get([]byte) ([]byte, error) }, key, wantValue string) { + value, err := s.Get([]byte(key)) + require.NoError(t, err) + if wantValue == "" { + require.Nil(t, value) + } else { + require.Equal(t, wantValue, string(value)) + } + } + + remove := func(s interface{ Delete([]byte) error }, key string) { + err := s.Delete([]byte(key)) + require.NoError(t, err) + } + + iter := func(s interface { + Iterator(start, end []byte) (store.Iterator, error) + }, start, end string, wantPairs [][2]string, + ) { + startKey := []byte(start) + endKey := []byte(end) + if start == "" { + startKey = nil + } + if end == "" { + endKey = nil + } + iter, err := s.Iterator(startKey, endKey) + require.NoError(t, err) + defer iter.Close() + numPairs := len(wantPairs) + for i := 0; i < numPairs; i++ { + require.True(t, iter.Valid(), "expected iterator to be valid") + gotKey, gotValue := string(iter.Key()), string(iter.Value()) + wantKey, wantValue := wantPairs[i][0], wantPairs[i][1] + require.Equal(t, wantKey, gotKey) + require.Equal(t, wantValue, gotValue) + iter.Next() + } + } + + parent := newMemState() + + // populate parent with some state + set(parent, "1", "a") + set(parent, "2", "b") + set(parent, "3", "c") + set(parent, "4", "d") + + branch := NewStore(parent) + + get(branch, "1", "a") // gets from parent + + set(branch, "1", "z") + get(branch, "1", "z") // gets updated value from branch + + set(branch, "5", "e") + get(branch, "5", "e") // gets updated value from branch + + remove(branch, "3") + get(branch, "3", "") // it's not fetched even if it exists in parent, it's not part of branch changeset currently. + + set(branch, "6", "f") + remove(branch, "6") + get(branch, "6", "") // inserted and then removed from branch + + // test iter + iter( + branch, + "", "", + [][2]string{ + {"1", "z"}, + {"2", "b"}, + {"4", "d"}, + {"5", "e"}, + }, + ) + + // test iter in range + iter( + branch, + "2", "4", + [][2]string{ + {"2", "b"}, + }, + ) + + // test reverse iter +} + +func newMemState() memStore { + return memStore{btree.NewBTreeGOptions(byKeys, btree.Options{Degree: bTreeDegree, NoLocks: true})} +} + +var _ store.Writer = memStore{} + +type memStore struct { + t *btree.BTreeG[item] +} + +func (m memStore) Set(key, value []byte) error { + m.t.Set(item{key: key, value: value}) + return nil +} + +func (m memStore) Delete(key []byte) error { + m.t.Delete(item{key: key}) + return nil +} + +func (m memStore) ApplyChangeSets(changes []store.KVPair) error { + panic("not callable") +} + +func (m memStore) ChangeSets() ([]store.KVPair, error) { panic("not callable") } + +func (m memStore) Has(key []byte) (bool, error) { + _, found := m.t.Get(item{key: key}) + return found, nil +} + +func (m memStore) Get(bytes []byte) ([]byte, error) { + v, found := m.t.Get(item{key: bytes}) + if !found { + return nil, nil + } + return v.value, nil +} + +func (m memStore) Iterator(start, end []byte) (store.Iterator, error) { + return newMemIterator(start, end, m.t, true), nil +} + +func (m memStore) ReverseIterator(start, end []byte) (store.Iterator, error) { + return newMemIterator(start, end, m.t, false), nil +} diff --git a/server/v2/stf/branch/changeset.go b/server/v2/stf/branch/changeset.go new file mode 100644 index 000000000000..13c016725130 --- /dev/null +++ b/server/v2/stf/branch/changeset.go @@ -0,0 +1,230 @@ +package branch + +import ( + "bytes" + "errors" + + "github.com/tidwall/btree" + + "cosmossdk.io/core/store" +) + +const ( + // The approximate number of items and children per B-tree node. Tuned with benchmarks. + // copied from memdb. + bTreeDegree = 32 +) + +var errKeyEmpty = errors.New("key cannot be empty") + +// changeSet implements the sorted cache for cachekv store, +// we don't use MemDB here because cachekv is used extensively in sdk core path, +// we need it to be as fast as possible, while `MemDB` is mainly used as a mocking db in unit tests. +// +// We choose tidwall/btree over google/btree here because it provides API to implement step iterator directly. +type changeSet struct { + tree *btree.BTreeG[item] +} + +// newChangeSet creates a wrapper around `btree.BTreeG`. +func newChangeSet() changeSet { + return changeSet{ + tree: btree.NewBTreeGOptions(byKeys, btree.Options{ + Degree: bTreeDegree, + NoLocks: true, + }), + } +} + +// set adds a new key-value pair to the change set's tree. +func (bt changeSet) set(key, value []byte) { + bt.tree.Set(newItem(key, value)) +} + +// get retrieves the value associated with the given key from the changeSet's tree. +func (bt changeSet) get(key []byte) (value []byte, found bool) { + it, found := bt.tree.Get(item{key: key}) + return it.value, found +} + +// delete removes the value associated with the given key from the change set. +// If the key does not exist in the change set, this method does nothing. +func (bt changeSet) delete(key []byte) { + bt.set(key, nil) +} + +// iterator returns a new iterator over the key-value pairs in the changeSet +// that have keys greater than or equal to the start key and less than the end key. +func (bt changeSet) iterator(start, end []byte) (store.Iterator, error) { + if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) { + return nil, errKeyEmpty + } + return newMemIterator(start, end, bt.tree, true), nil +} + +// reverseIterator returns a new iterator that iterates over the key-value pairs in reverse order +// within the specified range [start, end) in the changeSet's tree. +// If start or end is an empty byte slice, it returns an error indicating that the key is empty. +func (bt changeSet) reverseIterator(start, end []byte) (store.Iterator, error) { + if (start != nil && len(start) == 0) || (end != nil && len(end) == 0) { + return nil, errKeyEmpty + } + return newMemIterator(start, end, bt.tree, false), nil +} + +// item is a btree item with byte slices as keys and values +type item struct { + key []byte + value []byte +} + +// byKeys compares the items by key +func byKeys(a, b item) bool { + return bytes.Compare(a.key, b.key) == -1 +} + +// newItem creates a new pair item. +func newItem(key, value []byte) item { + return item{key: key, value: value} +} + +// memIterator iterates over iterKVCache items. +// if value is nil, means it was deleted. +// Implements Iterator. +type memIterator struct { + iter btree.IterG[item] + + start []byte + end []byte + ascending bool + valid bool +} + +// newMemIterator creates a new memory iterator for a given range of keys in a B-tree. +// The iterator starts at the specified start key and ends at the specified end key. +// The `tree` parameter is the B-tree to iterate over. +// The `ascending` parameter determines the direction of iteration. +// If `ascending` is true, the iterator will iterate in ascending order. +// If `ascending` is false, the iterator will iterate in descending order. +// The returned iterator is positioned at the first key that is greater than or equal to the start key. +// If the start key is nil, the iterator is positioned at the first key in the B-tree. +// If the end key is nil, the iterator is positioned at the last key in the B-tree. +// The iterator is inclusive of the start key and exclusive of the end key. +// The `valid` field of the iterator indicates whether the iterator is positioned at a valid key. +// The `start` and `end` fields of the iterator store the start and end keys respectively. +func newMemIterator(start, end []byte, tree *btree.BTreeG[item], ascending bool) *memIterator { + iter := tree.Iter() + var valid bool + if ascending { + if start != nil { + valid = iter.Seek(newItem(start, nil)) + } else { + valid = iter.First() + } + } else { + if end != nil { + valid = iter.Seek(newItem(end, nil)) + if !valid { + valid = iter.Last() + } else { + // end is exclusive + valid = iter.Prev() + } + } else { + valid = iter.Last() + } + } + + mi := &memIterator{ + iter: iter, + start: start, + end: end, + ascending: ascending, + valid: valid, + } + + if mi.valid { + mi.valid = mi.keyInRange(mi.Key()) + } + + return mi +} + +// Domain returns the start and end keys of the iterator's domain. +func (mi *memIterator) Domain() (start, end []byte) { + return mi.start, mi.end +} + +// Close releases any resources held by the iterator. +func (mi *memIterator) Close() error { + mi.iter.Release() + return nil +} + +// Error returns the error state of the iterator. +// If the iterator is not valid, it returns the errInvalidIterator error. +// Otherwise, it returns nil. +func (mi *memIterator) Error() error { + if !mi.Valid() { + return errInvalidIterator + } + return nil +} + +// Valid returns whether the iterator is currently pointing to a valid entry. +// It returns true if the iterator is valid, and false otherwise. +func (mi *memIterator) Valid() bool { + return mi.valid +} + +// Next advances the iterator to the next key-value pair. +// If the iterator is in ascending order, it moves to the next key-value pair. +// If the iterator is in descending order, it moves to the previous key-value pair. +// It also checks if the new key-value pair is within the specified range. +func (mi *memIterator) Next() { + mi.assertValid() + + if mi.ascending { + mi.valid = mi.iter.Next() + } else { + mi.valid = mi.iter.Prev() + } + + if mi.valid { + mi.valid = mi.keyInRange(mi.Key()) + } +} + +// keyInRange checks if the given key is within the range defined by the iterator. +// If the iterator is in ascending order and the end key is not nil, it returns false +// if the key is greater than or equal to the end key. +// If the iterator is in descending order and the start key is not nil, it returns false +// if the key is less than the start key. +// Otherwise, it returns true. +func (mi *memIterator) keyInRange(key []byte) bool { + if mi.ascending && mi.end != nil && bytes.Compare(key, mi.end) >= 0 { + return false + } + if !mi.ascending && mi.start != nil && bytes.Compare(key, mi.start) < 0 { + return false + } + return true +} + +// Key returns the key of the current item in the iterator. +func (mi *memIterator) Key() []byte { + return mi.iter.Item().key +} + +// Value returns the value of the current item in the iterator. +func (mi *memIterator) Value() []byte { + return mi.iter.Item().value +} + +// assertValid checks if the memIterator is in a valid state. +// If there is an error, it panics with the error message. +func (mi *memIterator) assertValid() { + if err := mi.Error(); err != nil { + panic(err) + } +} diff --git a/server/v2/stf/branch/defaults.go b/server/v2/stf/branch/defaults.go new file mode 100644 index 000000000000..19f68933f9d3 --- /dev/null +++ b/server/v2/stf/branch/defaults.go @@ -0,0 +1,9 @@ +package branch + +import "cosmossdk.io/core/store" + +func DefaultNewWriterMap(r store.ReaderMap) store.WriterMap { + return NewWriterMap(r, func(readonlyState store.Reader) store.Writer { + return NewStore(readonlyState) + }) +} diff --git a/server/v2/stf/branch/doc.go b/server/v2/stf/branch/doc.go new file mode 100644 index 000000000000..9fc02d7261c4 --- /dev/null +++ b/server/v2/stf/branch/doc.go @@ -0,0 +1,3 @@ +// Package branch defines a Store that can be used to wrap readable state to make it writable. +// Code heavily taken and adapted from cosmossdk.io/store/v1. +package branch diff --git a/server/v2/stf/branch/mergeiter.go b/server/v2/stf/branch/mergeiter.go new file mode 100644 index 000000000000..e71b88cffc42 --- /dev/null +++ b/server/v2/stf/branch/mergeiter.go @@ -0,0 +1,235 @@ +package branch + +import ( + "bytes" + "errors" + + corestore "cosmossdk.io/core/store" +) + +// mergedIterator merges a parent Iterator and a cache Iterator. +// The cache iterator may return nil keys to signal that an item +// had been deleted (but not deleted in the parent). +// If the cache iterator has the same key as the parent, the +// cache shadows (overrides) the parent. +type mergedIterator struct { + parent corestore.Iterator + cache corestore.Iterator + ascending bool + + valid bool +} + +var _ corestore.Iterator = (*mergedIterator)(nil) + +// mergeIterators merges two iterators. +func mergeIterators(parent, cache corestore.Iterator, ascending bool) corestore.Iterator { + iter := &mergedIterator{ + parent: parent, + cache: cache, + ascending: ascending, + } + + iter.valid = iter.skipUntilExistsOrInvalid() + return iter +} + +// Domain implements Iterator. +// Returns parent domain because cache and parent domains are the same. +func (iter *mergedIterator) Domain() (start, end []byte) { + return iter.parent.Domain() +} + +// Valid implements Iterator. +func (iter *mergedIterator) Valid() bool { + return iter.valid +} + +// Next implements Iterator +func (iter *mergedIterator) Next() { + iter.assertValid() + + switch { + case !iter.parent.Valid(): + // If parent is invalid, get the next cache item. + iter.cache.Next() + case !iter.cache.Valid(): + // If cache is invalid, get the next parent item. + iter.parent.Next() + default: + // Both are valid. Compare keys. + keyP, keyC := iter.parent.Key(), iter.cache.Key() + switch iter.compare(keyP, keyC) { + case -1: // parent < cache + iter.parent.Next() + case 0: // parent == cache + iter.parent.Next() + iter.cache.Next() + case 1: // parent > cache + iter.cache.Next() + } + } + iter.valid = iter.skipUntilExistsOrInvalid() +} + +// Key implements Iterator +func (iter *mergedIterator) Key() []byte { + iter.assertValid() + + // If parent is invalid, get the cache key. + if !iter.parent.Valid() { + return iter.cache.Key() + } + + // If cache is invalid, get the parent key. + if !iter.cache.Valid() { + return iter.parent.Key() + } + + // Both are valid. Compare keys. + keyP, keyC := iter.parent.Key(), iter.cache.Key() + + cmp := iter.compare(keyP, keyC) + switch cmp { + case -1: // parent < cache + return keyP + case 0: // parent == cache + return keyP + case 1: // parent > cache + return keyC + default: + panic("invalid compare result") + } +} + +// Value implements Iterator +func (iter *mergedIterator) Value() []byte { + iter.assertValid() + + // If parent is invalid, get the cache value. + if !iter.parent.Valid() { + return iter.cache.Value() + } + + // If cache is invalid, get the parent value. + if !iter.cache.Valid() { + return iter.parent.Value() + } + + // Both are valid. Compare keys. + keyP, keyC := iter.parent.Key(), iter.cache.Key() + + cmp := iter.compare(keyP, keyC) + switch cmp { + case -1: // parent < cache + return iter.parent.Value() + case 0: // parent == cache + return iter.cache.Value() + case 1: // parent > cache + return iter.cache.Value() + default: + panic("invalid comparison result") + } +} + +// Close implements Iterator +func (iter *mergedIterator) Close() error { + err1 := iter.cache.Close() + if err := iter.parent.Close(); err != nil { + return err + } + + return err1 +} + +var errInvalidIterator = errors.New("invalid merged iterator") + +// Error returns an error if the mergedIterator is invalid defined by the +// Valid method. +func (iter *mergedIterator) Error() error { + if !iter.Valid() { + return errInvalidIterator + } + + return nil +} + +// If not valid, panics. +// NOTE: May have side-effect of iterating over cache. +func (iter *mergedIterator) assertValid() { + if err := iter.Error(); err != nil { + panic(err) + } +} + +// Like bytes.Compare but opposite if not ascending. +func (iter *mergedIterator) compare(a, b []byte) int { + if iter.ascending { + return bytes.Compare(a, b) + } + + return bytes.Compare(a, b) * -1 +} + +// Skip all delete-items from the cache w/ `key < until`. After this function, +// current cache item is a non-delete-item, or `until <= key`. +// If the current cache item is not a delete item, does nothing. +// If `until` is nil, there is no limit, and cache may end up invalid. +// CONTRACT: cache is valid. +func (iter *mergedIterator) skipCacheDeletes(until []byte) { + for iter.cache.Valid() && + iter.cache.Value() == nil && + (until == nil || iter.compare(iter.cache.Key(), until) < 0) { + iter.cache.Next() + } +} + +// Fast forwards cache (or parent+cache in case of deleted items) until current +// item exists, or until iterator becomes invalid. +// Returns whether the iterator is valid. +func (iter *mergedIterator) skipUntilExistsOrInvalid() bool { + for { + // If parent is invalid, fast-forward cache. + if !iter.parent.Valid() { + iter.skipCacheDeletes(nil) + return iter.cache.Valid() + } + // Parent is valid. + + if !iter.cache.Valid() { + return true + } + // Parent is valid, cache is valid. + + // Compare parent and cache. + keyP := iter.parent.Key() + keyC := iter.cache.Key() + + switch iter.compare(keyP, keyC) { + case -1: // parent < cache. + return true + + case 0: // parent == cache. + // Skip over if cache item is a delete. + valueC := iter.cache.Value() + if valueC == nil { + iter.parent.Next() + iter.cache.Next() + + continue + } + // Cache is not a delete. + + return true // cache exists. + case 1: // cache < parent + // Skip over if cache item is a delete. + valueC := iter.cache.Value() + if valueC == nil { + iter.skipCacheDeletes(keyP) + continue + } + // Cache is not a delete. + return true // cache exists. + } + } +} diff --git a/server/v2/stf/branch/store.go b/server/v2/stf/branch/store.go new file mode 100644 index 000000000000..f0d6d0b3a1ea --- /dev/null +++ b/server/v2/stf/branch/store.go @@ -0,0 +1,134 @@ +package branch + +import ( + "errors" + + "cosmossdk.io/core/store" +) + +var _ store.Writer = (*Store[store.Reader])(nil) + +// Store wraps an in-memory cache around an underlying types.KVStore. +type Store[T store.Reader] struct { + changeSet changeSet // always ascending sorted + parent T +} + +// NewStore creates a new Store object +func NewStore[T store.Reader](parent T) Store[T] { + return Store[T]{ + changeSet: newChangeSet(), + parent: parent, + } +} + +// Get implements types.KVStore. +func (s Store[T]) Get(key []byte) (value []byte, err error) { + value, found := s.changeSet.get(key) + if found { + return + } + return s.parent.Get(key) +} + +// Set implements types.KVStore. +func (s Store[T]) Set(key, value []byte) error { + if value == nil { + return errors.New("cannot set a nil value") + } + + s.changeSet.set(key, value) + return nil +} + +// Has implements types.KVStore. +func (s Store[T]) Has(key []byte) (bool, error) { + tmpValue, found := s.changeSet.get(key) + if found { + return tmpValue != nil, nil + } + return s.parent.Has(key) +} + +// Delete implements types.KVStore. +func (s Store[T]) Delete(key []byte) error { + s.changeSet.delete(key) + return nil +} + +// ---------------------------------------- +// Iteration + +// Iterator implements types.KVStore. +func (s Store[T]) Iterator(start, end []byte) (store.Iterator, error) { + return s.iterator(start, end, true) +} + +// ReverseIterator implements types.KVStore. +func (s Store[T]) ReverseIterator(start, end []byte) (store.Iterator, error) { + return s.iterator(start, end, false) +} + +func (s Store[T]) iterator(start, end []byte, ascending bool) (store.Iterator, error) { + var ( + err error + parent, cache store.Iterator + ) + + if ascending { + parent, err = s.parent.Iterator(start, end) + if err != nil { + return nil, err + } + cache, err = s.changeSet.iterator(start, end) + if err != nil { + return nil, err + } + return mergeIterators(parent, cache, ascending), nil + } else { + parent, err = s.parent.ReverseIterator(start, end) + if err != nil { + return nil, err + } + cache, err = s.changeSet.reverseIterator(start, end) + if err != nil { + return nil, err + } + return mergeIterators(parent, cache, ascending), nil + } +} + +func (s Store[T]) ApplyChangeSets(changes []store.KVPair) error { + for _, c := range changes { + if c.Remove { + err := s.Delete(c.Key) + if err != nil { + return err + } + } else { + err := s.Set(c.Key, c.Value) + if err != nil { + return err + } + } + } + return nil +} + +func (s Store[T]) ChangeSets() (cs []store.KVPair, err error) { + iter, err := s.changeSet.iterator(nil, nil) + if err != nil { + return nil, err + } + defer iter.Close() + + for ; iter.Valid(); iter.Next() { + k, v := iter.Key(), iter.Value() + cs = append(cs, store.KVPair{ + Key: k, + Value: v, + Remove: v == nil, // maybe we can optimistically compute size. + }) + } + return cs, nil +} diff --git a/server/v2/stf/branch/writer_map.go b/server/v2/stf/branch/writer_map.go new file mode 100644 index 000000000000..b624a4d2532c --- /dev/null +++ b/server/v2/stf/branch/writer_map.go @@ -0,0 +1,79 @@ +package branch + +import ( + "fmt" + "unsafe" + + "cosmossdk.io/core/store" +) + +func NewWriterMap( + state store.ReaderMap, + branch func(readonlyState store.Reader) store.Writer, +) store.WriterMap { + return WriterMap{ + state: state, + branchedWriterState: make(map[string]store.Writer), + branch: branch, + } +} + +// WriterMap implements a branched version of the store.WriterMap. +// After the firs time the actor's branched Store is created, it is +// memoized in the WriterMap. +type WriterMap struct { + state store.ReaderMap + branchedWriterState map[string]store.Writer + branch func(state store.Reader) store.Writer +} + +func (b WriterMap) GetReader(actor []byte) (store.Reader, error) { + return b.GetWriter(actor) +} + +func (b WriterMap) GetWriter(actor []byte) (store.Writer, error) { + // Simplify and optimize state retrieval + if actorState, ok := b.branchedWriterState[unsafeString(actor)]; ok { + return actorState, nil + } else if writerState, err := b.state.GetReader(actor); err != nil { + return nil, err + } else { + actorState = b.branch(writerState) + b.branchedWriterState[string(actor)] = actorState + return actorState, nil + } +} + +func (b WriterMap) ApplyStateChanges(stateChanges []store.StateChanges) error { + for _, sc := range stateChanges { + if err := b.applyStateChange(sc); err != nil { + return fmt.Errorf("unable to apply state change for actor %X: %w", sc.Actor, err) + } + } + return nil +} + +func (b WriterMap) GetStateChanges() ([]store.StateChanges, error) { + sc := make([]store.StateChanges, len(b.branchedWriterState)) + for account, stateChange := range b.branchedWriterState { + kvChanges, err := stateChange.ChangeSets() + if err != nil { + return nil, err + } + sc = append(sc, store.StateChanges{ + Actor: []byte(account), + StateChanges: kvChanges, + }) + } + return sc, nil +} + +func (b WriterMap) applyStateChange(sc store.StateChanges) error { + writableState, err := b.GetWriter(sc.Actor) + if err != nil { + return err + } + return writableState.ApplyChangeSets(sc.StateChanges) +} + +func unsafeString(b []byte) string { return *(*string)(unsafe.Pointer(&b)) } diff --git a/server/v2/stf/core_branch_service.go b/server/v2/stf/core_branch_service.go new file mode 100644 index 000000000000..365d73d532b4 --- /dev/null +++ b/server/v2/stf/core_branch_service.go @@ -0,0 +1,75 @@ +package stf + +import ( + "context" + + "cosmossdk.io/core/branch" + "cosmossdk.io/core/store" +) + +type branchFn func(state store.ReaderMap) store.WriterMap + +var _ branch.Service = (*BranchService)(nil) + +type BranchService struct{} + +func (bs BranchService) Execute(ctx context.Context, f func(ctx context.Context) error) error { + return bs.execute(ctx.(*executionContext), f) +} + +func (bs BranchService) ExecuteWithGasLimit( + ctx context.Context, + gasLimit uint64, + f func(ctx context.Context) error, +) (gasUsed uint64, err error) { + stfCtx := ctx.(*executionContext) + + originalGasMeter := stfCtx.meter + + stfCtx.setGasLimit(gasLimit) + + // execute branched, with predefined gas limit. + err = bs.execute(stfCtx, f) + // restore original context + gasUsed = stfCtx.meter.Limit() - stfCtx.meter.Remaining() + _ = originalGasMeter.Consume(gasUsed, "execute-with-gas-limit") + stfCtx.setGasLimit(originalGasMeter.Limit() - originalGasMeter.Remaining()) + + return gasUsed, err +} + +func (bs BranchService) execute(ctx *executionContext, f func(ctx context.Context) error) error { + branchedState := ctx.branchFn(ctx.unmeteredState) + meteredBranchedState := ctx.makeGasMeteredStore(ctx.meter, branchedState) + + branchedCtx := &executionContext{ + Context: ctx.Context, + unmeteredState: branchedState, + state: meteredBranchedState, + meter: ctx.meter, + events: nil, + sender: ctx.sender, + headerInfo: ctx.headerInfo, + execMode: ctx.execMode, + branchFn: ctx.branchFn, + makeGasMeter: ctx.makeGasMeter, + makeGasMeteredStore: ctx.makeGasMeteredStore, + } + + err := f(branchedCtx) + if err != nil { + return err + } + + // apply state changes to original state + if len(branchedCtx.events) != 0 { + ctx.events = append(ctx.events, branchedCtx.events...) + } + + err = applyStateChanges(ctx.state, branchedCtx.unmeteredState) + if err != nil { + return err + } + + return nil +} diff --git a/server/v2/stf/core_branch_service_test.go b/server/v2/stf/core_branch_service_test.go new file mode 100644 index 000000000000..6eee1581377e --- /dev/null +++ b/server/v2/stf/core_branch_service_test.go @@ -0,0 +1,93 @@ +package stf + +import ( + "context" + "fmt" + "testing" + + appmodulev2 "cosmossdk.io/core/appmodule/v2" + "cosmossdk.io/core/transaction" + "cosmossdk.io/server/v2/stf/branch" + "cosmossdk.io/server/v2/stf/gas" + "cosmossdk.io/server/v2/stf/mock" + "github.com/stretchr/testify/require" +) + +func TestBranchService(t *testing.T) { + s := &STF[mock.Tx]{ + handleMsg: func(ctx context.Context, msg transaction.Type) (msgResp transaction.Type, err error) { + kvSet(t, ctx, "exec") + return nil, nil + }, + handleQuery: nil, + doPreBlock: func(ctx context.Context, txs []mock.Tx) error { return nil }, + doBeginBlock: func(ctx context.Context) error { + kvSet(t, ctx, "begin-block") + return nil + }, + doEndBlock: func(ctx context.Context) error { + kvSet(t, ctx, "end-block") + return nil + }, + doValidatorUpdate: func(ctx context.Context) ([]appmodulev2.ValidatorUpdate, error) { return nil, nil }, + doTxValidation: func(ctx context.Context, tx mock.Tx) error { + kvSet(t, ctx, "validate") + return nil + }, + postTxExec: func(ctx context.Context, tx mock.Tx, success bool) error { + kvSet(t, ctx, "post-tx-exec") + return nil + }, + branchFn: branch.DefaultNewWriterMap, + makeGasMeter: gas.DefaultGasMeter, + makeGasMeteredState: gas.DefaultWrapWithGasMeter, + } + + makeContext := func() *executionContext { + state := mock.DB() + writableState := s.branchFn(state) + ctx := s.makeContext(context.Background(), []byte("cookies"), writableState, 0) + ctx.setGasLimit(1000000) + return ctx + } + + branchService := BranchService{} + + // TODO: add events check + gas limit precision test + + t.Run("ok", func(t *testing.T) { + stfCtx := makeContext() + gasUsed, err := branchService.ExecuteWithGasLimit(stfCtx, 10000, func(ctx context.Context) error { + kvSet(t, ctx, "cookies") + return nil + }) + require.NoError(t, err) + require.NotZero(t, gasUsed) + stateHas(t, stfCtx.state, "cookies") + }) + + t.Run("fail - reverts state", func(t *testing.T) { + stfCtx := makeContext() + gasUsed, err := branchService.ExecuteWithGasLimit(stfCtx, 10000, func(ctx context.Context) error { + kvSet(t, ctx, "cookies") + return fmt.Errorf("fail") + }) + require.Error(t, err) + require.NotZero(t, gasUsed) + stateNotHas(t, stfCtx.state, "cookies") + }) + + t.Run("fail - out of gas", func(t *testing.T) { + stfCtx := makeContext() + + gasUsed, err := branchService.ExecuteWithGasLimit(stfCtx, 4000, func(ctx context.Context) error { + state, _ := ctx.(*executionContext).state.GetWriter(actorName) + _ = state.Set([]byte("not out of gas"), []byte{}) + return state.Set([]byte("out of gas"), []byte{}) + }) + require.Error(t, err) + require.NotZero(t, gasUsed) + stateNotHas(t, stfCtx.state, "cookies") + require.Equal(t, uint64(1000), stfCtx.meter.Limit()-stfCtx.meter.Remaining()) + }) +} diff --git a/server/v2/stf/core_event_service.go b/server/v2/stf/core_event_service.go new file mode 100644 index 000000000000..8258742d8896 --- /dev/null +++ b/server/v2/stf/core_event_service.go @@ -0,0 +1,87 @@ +package stf + +import ( + "context" + "encoding/json" + "slices" + + gogoproto "github.com/cosmos/gogoproto/proto" + "golang.org/x/exp/maps" + "google.golang.org/protobuf/runtime/protoiface" + + "cosmossdk.io/core/event" +) + +func NewEventService() event.Service { + return eventService{} +} + +type eventService struct{} + +// EventManager implements event.Service. +func (eventService) EventManager(ctx context.Context) event.Manager { + return &eventManager{ctx.(*executionContext)} +} + +var _ event.Manager = (*eventManager)(nil) + +type eventManager struct { + executionContext *executionContext +} + +// Emit emits an typed event that is defined in the protobuf file. +// In the future these events will be added to consensus. +func (em *eventManager) Emit(tev protoiface.MessageV1) error { + res, err := TypedEventToEvent(tev) + if err != nil { + return err + } + + em.executionContext.events = append(em.executionContext.events, res) + return nil +} + +// EmitKV emits a key value pair event. +func (em *eventManager) EmitKV(eventType string, attrs ...event.Attribute) error { + em.executionContext.events = append(em.executionContext.events, event.NewEvent(eventType, attrs...)) + return nil +} + +// EmitNonConsensus emits an typed event that is defined in the protobuf file. +// These events will not be added to consensus. +func (em *eventManager) EmitNonConsensus(event protoiface.MessageV1) error { + return em.Emit(event) +} + +// TypedEventToEvent takes typed event and converts to Event object +func TypedEventToEvent(tev gogoproto.Message) (event.Event, error) { + evtType := gogoproto.MessageName(tev) + evtJSON, err := gogoproto.Marshal(tev) + if err != nil { + return event.Event{}, err + } + + var attrMap map[string]json.RawMessage + err = json.Unmarshal(evtJSON, &attrMap) + if err != nil { + return event.Event{}, err + } + + // sort the keys to ensure the order is always the same + keys := maps.Keys(attrMap) + slices.Sort(keys) + + attrs := make([]event.Attribute, 0, len(attrMap)) + for _, k := range keys { + v := attrMap[k] + attrs = append(attrs, event.Attribute{ + Key: k, + Value: string(v), + }) + } + + return event.Event{ + Type: evtType, + Attributes: attrs, + }, nil +} diff --git a/server/v2/stf/core_gas_service.go b/server/v2/stf/core_gas_service.go new file mode 100644 index 000000000000..656fd23388ca --- /dev/null +++ b/server/v2/stf/core_gas_service.go @@ -0,0 +1,45 @@ +package stf + +import ( + "context" + + "cosmossdk.io/core/gas" + "cosmossdk.io/core/store" +) + +type ( + // makeGasMeterFn is a function type that takes a gas limit as input and returns a gas.Meter. + // It is used to measure and limit the amount of gas consumed during the execution of a function. + makeGasMeterFn func(gasLimit uint64) gas.Meter + + // makeGasMeteredStateFn is a function type that wraps a gas meter and a store writer map. + makeGasMeteredStateFn func(meter gas.Meter, store store.WriterMap) store.WriterMap +) + +// NewGasMeterService creates a new instance of the gas meter service. +func NewGasMeterService() gas.Service { + return gasService{} +} + +type gasService struct{} + +// GetGasConfig implements gas.Service. +func (g gasService) GasConfig(ctx context.Context) gas.GasConfig { + panic("unimplemented") +} + +func (g gasService) GasMeter(ctx context.Context) gas.Meter { + return ctx.(*executionContext).meter +} + +func (g gasService) BlockGasMeter(ctx context.Context) gas.Meter { + panic("stf has no block gas meter") +} + +func (g gasService) WithGasMeter(ctx context.Context, meter gas.Meter) context.Context { + panic("unimplemented") +} + +func (g gasService) WithBlockGasMeter(ctx context.Context, meter gas.Meter) context.Context { + panic("unimplemented") +} diff --git a/server/v2/stf/core_header_service.go b/server/v2/stf/core_header_service.go new file mode 100644 index 000000000000..730ce7e40646 --- /dev/null +++ b/server/v2/stf/core_header_service.go @@ -0,0 +1,17 @@ +package stf + +import ( + "context" + + "cosmossdk.io/core/header" +) + +var _ header.Service = (*HeaderService)(nil) + +type HeaderService struct { + getHeader func() (header.Info, error) +} + +func (h HeaderService) HeaderInfo(ctx context.Context) header.Info { + return ctx.(*executionContext).headerInfo +} diff --git a/server/v2/stf/core_store_service.go b/server/v2/stf/core_store_service.go new file mode 100644 index 000000000000..d912f9277157 --- /dev/null +++ b/server/v2/stf/core_store_service.go @@ -0,0 +1,33 @@ +package stf + +import ( + "context" + + "cosmossdk.io/core/store" +) + +var _ store.KVStoreService = (*storeService)(nil) + +func NewKVStoreService(address []byte) store.KVStoreService { + return storeService{actor: address} +} + +func NewMemoryStoreService(address []byte) store.MemoryStoreService { + return storeService{actor: address} +} + +type storeService struct { + actor []byte +} + +func (s storeService) OpenKVStore(ctx context.Context) store.KVStore { + state, err := ctx.(*executionContext).state.GetWriter(s.actor) + if err != nil { + panic(err) + } + return state +} + +func (s storeService) OpenMemoryStore(ctx context.Context) store.KVStore { + return s.OpenKVStore(ctx) +} diff --git a/server/v2/stf/export_test.go b/server/v2/stf/export_test.go new file mode 100644 index 000000000000..427374b9cd15 --- /dev/null +++ b/server/v2/stf/export_test.go @@ -0,0 +1,14 @@ +package stf + +import ( + "context" +) + +func GetExecutionContext(ctx context.Context) *executionContext { + executionCtx, ok := ctx.(*executionContext) + if !ok { + return nil + } + return executionCtx +} + diff --git a/server/v2/stf/gas/defaults.go b/server/v2/stf/gas/defaults.go new file mode 100644 index 000000000000..8906e31da627 --- /dev/null +++ b/server/v2/stf/gas/defaults.go @@ -0,0 +1,46 @@ +package gas + +import ( + coregas "cosmossdk.io/core/gas" + "cosmossdk.io/core/store" +) + +// DefaultWrapWithGasMeter defines the default wrap with gas meter function in stf. In case +// the meter sets as limit stf.NoGasLimit, then a fast path is taken and the store.WriterMap +// is returned. +func DefaultWrapWithGasMeter(meter coregas.Meter, state store.WriterMap) store.WriterMap { + if meter.Limit() == coregas.NoGasLimit { + return state + } + return NewMeteredWriterMap(DefaultConfig, meter, state) +} + +// DefaultGasMeter returns the default gas meter. In case it is coregas.NoGasLimit a NoOpMeter is returned. +func DefaultGasMeter(gasLimit uint64) coregas.Meter { + if gasLimit == coregas.NoGasLimit { + return NoOpMeter{} + } + return NewMeter(gasLimit) +} + +var DefaultConfig = StoreConfig{ + HasCost: 1000, + DeleteCostFlat: 1000, + ReadCostFlat: 1000, + ReadCostPerByte: 3, + WriteCostFlat: 2000, + WriteCostPerByte: 30, + IterNextCostFlat: 30, +} + +type NoOpMeter struct{} + +func (n NoOpMeter) Consumed() coregas.Gas { return 0 } + +func (n NoOpMeter) Limit() coregas.Gas { return coregas.NoGasLimit } + +func (n NoOpMeter) Consume(_ coregas.Gas, _ string) error { return nil } + +func (n NoOpMeter) Refund(_ coregas.Gas, _ string) error { return nil } + +func (n NoOpMeter) Remaining() coregas.Gas { return coregas.NoGasLimit } diff --git a/server/v2/stf/gas/meter.go b/server/v2/stf/gas/meter.go new file mode 100644 index 000000000000..9f830260b5c5 --- /dev/null +++ b/server/v2/stf/gas/meter.go @@ -0,0 +1,57 @@ +package gas + +import ( + "cosmossdk.io/core/gas" +) + +var _ gas.Meter = (*Meter)(nil) + +type Meter struct { + limit uint64 + consumed uint64 +} + +// NewMeter creates a new gas meter with the given gas limit. +// The gas meter keeps track of the gas consumed during execution. +func NewMeter(gasLimit uint64) gas.Meter { + return &Meter{ + limit: gasLimit, + consumed: 0, + } +} + +// Consumed returns the amount of gas consumed by the meter. +func (m *Meter) Consumed() gas.Gas { + return m.consumed +} + +// Limit returns the maximum gas limit allowed for the meter. +func (m *Meter) Limit() gas.Gas { + return m.limit +} + +// Consume consumes the specified amount of gas from the meter. +// It returns an error if the requested gas exceeds the remaining gas limit. +func (m *Meter) Consume(requested gas.Gas, _ string) error { + remaining := m.limit - m.consumed + if requested > remaining { + return gas.ErrOutOfGas + } + m.consumed += requested + return nil +} + +// Refund refunds the specified amount of gas. +// If the amount is less than the consumed gas, it subtracts the amount from the consumed gas. +// It returns nil error. +func (m *Meter) Refund(amount gas.Gas, _ string) error { + if amount < m.consumed { + m.consumed -= amount + } + return nil +} + +// Remaining returns the remaining gas limit. +func (m *Meter) Remaining() gas.Gas { + return m.limit - m.consumed +} diff --git a/server/v2/stf/gas/store.go b/server/v2/stf/gas/store.go new file mode 100644 index 000000000000..8ea15f12dd12 --- /dev/null +++ b/server/v2/stf/gas/store.go @@ -0,0 +1,182 @@ +package gas + +import ( + "cosmossdk.io/core/gas" + "cosmossdk.io/core/store" +) + +// Gas consumption descriptors. +const ( + DescIterNextCostFlat = "IterNextFlat" + DescValuePerByte = "ValuePerByte" + DescWritePerByte = "WritePerByte" + DescReadPerByte = "ReadPerByte" + DescWriteCostFlat = "WriteFlat" + DescReadCostFlat = "ReadFlat" + DescHas = "Has" + DescDelete = "Delete" +) + +type StoreConfig struct { + ReadCostFlat, ReadCostPerByte, HasCost gas.Gas + WriteCostFlat, WriteCostPerByte, DeleteCostFlat gas.Gas + IterNextCostFlat gas.Gas +} + +type Store struct { + parent store.Writer + gasMeter gas.Meter + gasConfig StoreConfig +} + +func NewStore(gc StoreConfig, meter gas.Meter, parent store.Writer) *Store { + return &Store{ + parent: parent, + gasMeter: meter, + gasConfig: gc, + } +} + +func (s *Store) Get(key []byte) ([]byte, error) { + if err := s.gasMeter.Consume(s.gasConfig.ReadCostFlat, DescReadCostFlat); err != nil { + return nil, err + } + + value, err := s.parent.Get(key) + if err := s.gasMeter.Consume(s.gasConfig.ReadCostPerByte*gas.Gas(len(key)), DescReadPerByte); err != nil { + return nil, err + } + if err := s.gasMeter.Consume(s.gasConfig.ReadCostPerByte*gas.Gas(len(value)), DescReadPerByte); err != nil { + return nil, err + } + + return value, err +} + +func (s *Store) Has(key []byte) (bool, error) { + if err := s.gasMeter.Consume(s.gasConfig.HasCost, DescHas); err != nil { + return false, err + } + + return s.parent.Has(key) +} + +func (s *Store) Set(key, value []byte) error { + if err := s.gasMeter.Consume(s.gasConfig.WriteCostFlat, DescWriteCostFlat); err != nil { + return err + } + if err := s.gasMeter.Consume(s.gasConfig.WriteCostPerByte*gas.Gas(len(key)), DescWritePerByte); err != nil { + return err + } + if err := s.gasMeter.Consume(s.gasConfig.WriteCostPerByte*gas.Gas(len(value)), DescWritePerByte); err != nil { + return err + } + + return s.parent.Set(key, value) +} + +func (s *Store) Delete(key []byte) error { + if err := s.gasMeter.Consume(s.gasConfig.DeleteCostFlat, DescDelete); err != nil { + return err + } + + return s.parent.Delete(key) +} + +func (s *Store) ApplyChangeSets(changes []store.KVPair) error { + return s.parent.ApplyChangeSets(changes) +} + +func (s *Store) ChangeSets() ([]store.KVPair, error) { + return s.parent.ChangeSets() +} + +func (s *Store) Iterator(start, end []byte) (store.Iterator, error) { + itr, err := s.parent.Iterator(start, end) + if err != nil { + return nil, err + } + + return newIterator(itr, s.gasMeter, s.gasConfig), nil +} + +func (s *Store) ReverseIterator(start, end []byte) (store.Iterator, error) { + itr, err := s.parent.ReverseIterator(start, end) + if err != nil { + return nil, err + } + + return newIterator(itr, s.gasMeter, s.gasConfig), nil +} + +var _ store.Iterator = (*iterator)(nil) + +type iterator struct { + gasMeter gas.Meter + gasConfig StoreConfig + parent store.Iterator +} + +func newIterator(parent store.Iterator, gm gas.Meter, gc StoreConfig) store.Iterator { + return &iterator{ + parent: parent, + gasConfig: gc, + gasMeter: gm, + } +} + +func (itr *iterator) Domain() ([]byte, []byte) { + return itr.parent.Domain() +} + +func (itr *iterator) Valid() bool { + return itr.parent.Valid() +} + +func (itr *iterator) Key() []byte { + return itr.parent.Key() +} + +func (itr *iterator) Value() []byte { + return itr.parent.Value() +} + +func (itr *iterator) Next() { + if err := itr.consumeGasSeek(); err != nil { + // closing the iterator prematurely to prevent further execution + itr.parent.Close() + return + } + itr.parent.Next() +} + +func (itr *iterator) Close() error { + return itr.parent.Close() +} + +func (itr *iterator) Error() error { + return itr.parent.Error() +} + +// consumeGasSeek consumes a fixed amount of gas for each iteration step and a +// variable gas cost based on the current key and value's length. This is called +// prior to the iterator's Next() call. +func (itr *iterator) consumeGasSeek() error { + if itr.Valid() { + key := itr.Key() + value := itr.Value() + + if err := itr.gasMeter.Consume(itr.gasConfig.ReadCostPerByte*gas.Gas(len(key)), DescValuePerByte); err != nil { + return err + } + if err := itr.gasMeter.Consume(itr.gasConfig.ReadCostPerByte*gas.Gas(len(value)), DescValuePerByte); err != nil { + return err + } + } + + if err := itr.gasMeter.Consume(itr.gasConfig.IterNextCostFlat, DescIterNextCostFlat); err != nil { + return err + } + + return nil +} diff --git a/server/v2/stf/gas/writer_map.go b/server/v2/stf/gas/writer_map.go new file mode 100644 index 000000000000..cd5fa406d175 --- /dev/null +++ b/server/v2/stf/gas/writer_map.go @@ -0,0 +1,56 @@ +package gas + +import ( + "unsafe" + + "cosmossdk.io/core/gas" + "cosmossdk.io/core/store" +) + +func NewMeteredWriterMap(conf StoreConfig, meter gas.Meter, state store.WriterMap) MeteredWriterMap { + return MeteredWriterMap{ + config: conf, + meter: meter, + state: state, + cacheMeteredStores: make(map[string]*Store), + } +} + +// MeteredWriterMap wraps store.Writer and returns a gas metered +// version of it. Since the gas meter is shared across different +// writers, the metered writers are memoized. +type MeteredWriterMap struct { + config StoreConfig + meter gas.Meter + state store.WriterMap + cacheMeteredStores map[string]*Store +} + +func (m MeteredWriterMap) GetReader(actor []byte) (store.Reader, error) { return m.GetWriter(actor) } + +func (m MeteredWriterMap) GetWriter(actor []byte) (store.Writer, error) { + cached, ok := m.cacheMeteredStores[unsafeString(actor)] + if ok { + return cached, nil + } + + state, err := m.state.GetWriter(actor) + if err != nil { + return nil, err + } + + meteredState := NewStore(m.config, m.meter, state) + m.cacheMeteredStores[string(actor)] = meteredState + + return meteredState, nil +} + +func (m MeteredWriterMap) ApplyStateChanges(stateChanges []store.StateChanges) error { + return m.state.ApplyStateChanges(stateChanges) +} + +func (m MeteredWriterMap) GetStateChanges() ([]store.StateChanges, error) { + return m.state.GetStateChanges() +} + +func unsafeString(b []byte) string { return *(*string)(unsafe.Pointer(&b)) } diff --git a/server/v2/stf/go.mod b/server/v2/stf/go.mod new file mode 100644 index 000000000000..11ad8575e390 --- /dev/null +++ b/server/v2/stf/go.mod @@ -0,0 +1,28 @@ +module cosmossdk.io/server/v2/stf + +go 1.21 + +replace cosmossdk.io/core => ../../../core + +require ( + cosmossdk.io/core v0.11.0 + github.com/cosmos/gogoproto v1.4.12 + github.com/stretchr/testify v1.9.0 + github.com/tidwall/btree v1.7.0 + golang.org/x/exp v0.0.0-20231006140011-7918f672742d + google.golang.org/protobuf v1.34.1 +) + +require ( + cosmossdk.io/log v1.3.1 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/go-cmp v0.6.0 // indirect + github.com/kr/text v0.1.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rs/zerolog v1.32.0 // indirect + golang.org/x/sys v0.19.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/server/v2/stf/go.sum b/server/v2/stf/go.sum new file mode 100644 index 000000000000..e1d82724ba8d --- /dev/null +++ b/server/v2/stf/go.sum @@ -0,0 +1,50 @@ +cosmossdk.io/log v1.3.1 h1:UZx8nWIkfbbNEWusZqzAx3ZGvu54TZacWib3EzUYmGI= +cosmossdk.io/log v1.3.1/go.mod h1:2/dIomt8mKdk6vl3OWJcPk2be3pGOS8OQaLUM/3/tCM= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/cosmos/gogoproto v1.4.12 h1:vB6Lbe/rtnYGjQuFxkPiPYiCybqFT8QvLipDZP8JpFE= +github.com/cosmos/gogoproto v1.4.12/go.mod h1:LnZob1bXRdUoqMMtwYlcR3wjiElmlC+FkjaZRv1/eLY= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0= +github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= +github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI= +golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= +google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= +gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= diff --git a/server/v2/stf/mock/db.go b/server/v2/stf/mock/db.go new file mode 100644 index 000000000000..fbc405c2b134 --- /dev/null +++ b/server/v2/stf/mock/db.go @@ -0,0 +1,38 @@ +package mock + +import ( + "cosmossdk.io/core/store" +) + +func DB() store.ReaderMap { + return actorState{kv: map[string][]byte{}} +} + +type actorState struct { + kv map[string][]byte +} + +func (m actorState) GetReader(address []byte) (store.Reader, error) { + return memState{address, m.kv}, nil +} + +type memState struct { + address []byte + kv map[string][]byte +} + +func (m memState) Has(key []byte) (bool, error) { + v, err := m.Get(key) + return v != nil, err +} + +func (m memState) Get(bytes []byte) ([]byte, error) { + key := append(m.address, bytes...) + return m.kv[string(key)], nil +} + +func (m memState) Iterator(start, end []byte) (store.Iterator, error) { panic("implement me") } + +func (m memState) ReverseIterator(start, end []byte) (store.Iterator, error) { + panic("implement me") +} diff --git a/server/v2/stf/mock/tx.go b/server/v2/stf/mock/tx.go new file mode 100644 index 000000000000..aacca5d47e4f --- /dev/null +++ b/server/v2/stf/mock/tx.go @@ -0,0 +1,108 @@ +package mock + +import ( + "crypto/sha256" + "encoding/json" + "errors" + + "google.golang.org/protobuf/types/known/anypb" + + "cosmossdk.io/core/transaction" +) + +var _ transaction.Tx = Tx{} + +type Tx struct { + Sender []byte + Msg transaction.Type + GasLimit uint64 +} + +func (t Tx) Hash() [32]byte { + return sha256.Sum256(t.Bytes()) +} + +func (t Tx) GetMessages() ([]transaction.Type, error) { + if t.Msg == nil { + return nil, errors.New("messages not available or are nil") + } + return []transaction.Type{t.Msg}, nil +} + +func (t Tx) GetSenders() ([]transaction.Identity, error) { + if t.Sender == nil { + return nil, errors.New("senders not available or are nil") + } + return []transaction.Identity{t.Sender}, nil +} + +func (t Tx) GetGasLimit() (uint64, error) { + return t.GasLimit, nil +} + +type encodedTx struct { + Sender []byte `json:"sender"` + Msg *anypb.Any `json:"message"` + GasLimit uint64 `json:"gas_limit"` +} + +func (t Tx) Bytes() []byte { + v2Msg := t.Msg + msg, err := anypb.New(v2Msg) + if err != nil { + panic(err) + } + tx, err := json.Marshal(encodedTx{ + Sender: t.Sender, + Msg: msg, + GasLimit: t.GasLimit, + }) + if err != nil { + panic(err) + } + return tx +} + +func (t *Tx) Decode(b []byte) { + rawTx := new(encodedTx) + err := json.Unmarshal(b, rawTx) + if err != nil { + panic(err) + } + msg, err := rawTx.Msg.UnmarshalNew() + if err != nil { + panic(err) + } + t.Msg = msg + t.Sender = rawTx.Sender + t.GasLimit = rawTx.GasLimit +} + +func (t *Tx) DecodeJSON(b []byte) { + rawTx := new(encodedTx) + err := json.Unmarshal(b, rawTx) + if err != nil { + panic(err) + } + msg, err := rawTx.Msg.UnmarshalNew() + if err != nil { + panic(err) + } + t.Msg = msg + t.Sender = rawTx.Sender + t.GasLimit = rawTx.GasLimit +} + +type TxCodec struct{} + +func (TxCodec) Decode(bytes []byte) (Tx, error) { + t := new(Tx) + t.Decode(bytes) + return *t, nil +} + +func (TxCodec) DecodeJSON(bytes []byte) (Tx, error) { + t := new(Tx) + t.DecodeJSON(bytes) + return *t, nil +} diff --git a/server/v2/stf/stf.go b/server/v2/stf/stf.go new file mode 100644 index 000000000000..e47829005b22 --- /dev/null +++ b/server/v2/stf/stf.go @@ -0,0 +1,641 @@ +package stf + +import ( + "context" + "errors" + "fmt" + + appmanager "cosmossdk.io/core/app" + appmodulev2 "cosmossdk.io/core/appmodule/v2" + corecontext "cosmossdk.io/core/context" + "cosmossdk.io/core/event" + "cosmossdk.io/core/gas" + "cosmossdk.io/core/header" + "cosmossdk.io/core/store" + "cosmossdk.io/core/transaction" + "cosmossdk.io/log" + stfgas "cosmossdk.io/server/v2/stf/gas" +) + +// STF is a struct that manages the state transition component of the app. +type STF[T transaction.Tx] struct { + logger log.Logger + handleMsg func(ctx context.Context, msg transaction.Type) (transaction.Type, error) + handleQuery func(ctx context.Context, req transaction.Type) (transaction.Type, error) + + doPreBlock func(ctx context.Context, txs []T) error + doBeginBlock func(ctx context.Context) error + doEndBlock func(ctx context.Context) error + doValidatorUpdate func(ctx context.Context) ([]appmodulev2.ValidatorUpdate, error) + + doTxValidation func(ctx context.Context, tx T) error + postTxExec func(ctx context.Context, tx T, success bool) error + + branchFn branchFn // branchFn is a function that given a readonly state it returns a writable version of it. + makeGasMeter makeGasMeterFn + makeGasMeteredState makeGasMeteredStateFn +} + +// NewSTF returns a new STF instance. +func NewSTF[T transaction.Tx]( + handleMsg func(ctx context.Context, msg transaction.Type) (transaction.Type, error), + handleQuery func(ctx context.Context, req transaction.Type) (transaction.Type, error), + doPreBlock func(ctx context.Context, txs []T) error, + doBeginBlock func(ctx context.Context) error, + doEndBlock func(ctx context.Context) error, + doTxValidation func(ctx context.Context, tx T) error, + doValidatorUpdate func(ctx context.Context) ([]appmodulev2.ValidatorUpdate, error), + postTxExec func(ctx context.Context, tx T, success bool) error, + branch func(store store.ReaderMap) store.WriterMap, +) *STF[T] { + return &STF[T]{ + handleMsg: handleMsg, + handleQuery: handleQuery, + doPreBlock: doPreBlock, + doBeginBlock: doBeginBlock, + doEndBlock: doEndBlock, + doTxValidation: doTxValidation, + doValidatorUpdate: doValidatorUpdate, + postTxExec: postTxExec, // TODO + branchFn: branch, + makeGasMeter: stfgas.DefaultGasMeter, + makeGasMeteredState: stfgas.DefaultWrapWithGasMeter, + } +} + +// DeliverBlock is our state transition function. +// It takes a read only view of the state to apply the block to, +// executes the block and returns the block results and the new state. +func (s STF[T]) DeliverBlock( + ctx context.Context, + block *appmanager.BlockRequest[T], + state store.ReaderMap, +) (blockResult *appmanager.BlockResponse, newState store.WriterMap, err error) { + // creates a new branchFn state, from the readonly view of the state + // that can be written to. + newState = s.branchFn(state) + hi := header.Info{ + Hash: block.Hash, + AppHash: block.AppHash, + ChainID: block.ChainId, + Time: block.Time, + Height: int64(block.Height), + } + // set header info + err = s.setHeaderInfo(newState, hi) + if err != nil { + return nil, nil, fmt.Errorf("unable to set initial header info, %w", err) + } + + exCtx := s.makeContext(ctx, appmanager.ConsensusIdentity, newState, corecontext.ExecModeFinalize) + exCtx.setHeaderInfo(hi) + consMessagesResponses, err := s.runConsensusMessages(exCtx, block.ConsensusMessages) + if err != nil { + return nil, nil, fmt.Errorf("failed to execute consensus messages: %w", err) + } + + // reset events + exCtx.events = make([]event.Event, 0) + // pre block is called separate from begin block in order to prepopulate state + preBlockEvents, err := s.preBlock(exCtx, block.Txs) + if err != nil { + return nil, nil, err + } + + if err = isCtxCancelled(ctx); err != nil { + return nil, nil, err + } + + // reset events + exCtx.events = make([]event.Event, 0) + // begin block + beginBlockEvents, err := s.beginBlock(exCtx) + if err != nil { + return nil, nil, err + } + + // check if we need to return early + if err = isCtxCancelled(ctx); err != nil { + return nil, nil, err + } + + // execute txs + txResults := make([]appmanager.TxResult, len(block.Txs)) + // TODO: skip first tx if vote extensions are enabled (marko) + for i, txBytes := range block.Txs { + // check if we need to return early or continue delivering txs + if err = isCtxCancelled(ctx); err != nil { + return nil, nil, err + } + txResults[i] = s.deliverTx(ctx, newState, txBytes, corecontext.ExecModeFinalize, hi) + } + // reset events + exCtx.events = make([]event.Event, 0) + // end block + endBlockEvents, valset, err := s.endBlock(exCtx) + if err != nil { + return nil, nil, err + } + + return &appmanager.BlockResponse{ + Apphash: nil, + ConsensusMessagesResponse: consMessagesResponses, + ValidatorUpdates: valset, + PreBlockEvents: preBlockEvents, + BeginBlockEvents: beginBlockEvents, + TxResults: txResults, + EndBlockEvents: endBlockEvents, + }, newState, nil +} + +// deliverTx executes a TX and returns the result. +func (s STF[T]) deliverTx( + ctx context.Context, + state store.WriterMap, + tx T, + execMode corecontext.ExecMode, + hi header.Info, +) appmanager.TxResult { + // recover in the case of a panic + var recoveryError error + defer func() { + if r := recover(); r != nil { + recoveryError = fmt.Errorf("panic during transaction execution: %s", r) + s.logger.Error("panic during transaction execution", "error", recoveryError) + } + }() + // handle error from GetGasLimit + gasLimit, gasLimitErr := tx.GetGasLimit() + if gasLimitErr != nil { + return appmanager.TxResult{ + Error: gasLimitErr, + } + } + + if recoveryError != nil { + return appmanager.TxResult{ + Error: recoveryError, + } + } + + validateGas, validationEvents, err := s.validateTx(ctx, state, gasLimit, tx) + if err != nil { + return appmanager.TxResult{ + Error: err, + } + } + + execResp, execGas, execEvents, err := s.execTx(ctx, state, gasLimit-validateGas, tx, execMode, hi) + return appmanager.TxResult{ + Events: append(validationEvents, execEvents...), + GasUsed: execGas + validateGas, + GasWanted: gasLimit, + Resp: execResp, + Error: err, + } +} + +// validateTx validates a transaction given the provided WritableState and gas limit. +// If the validation is successful, state is committed +func (s STF[T]) validateTx( + ctx context.Context, + state store.WriterMap, + gasLimit uint64, + tx T, +) (gasUsed uint64, events []event.Event, err error) { + validateState := s.branchFn(state) + hi, err := s.getHeaderInfo(validateState) + if err != nil { + return 0, nil, err + } + validateCtx := s.makeContext(ctx, appmanager.RuntimeIdentity, validateState, corecontext.ExecModeCheck) + validateCtx.setHeaderInfo(hi) + validateCtx.setGasLimit(gasLimit) + err = s.doTxValidation(validateCtx, tx) + if err != nil { + return 0, nil, err + } + + consumed := validateCtx.meter.Limit() - validateCtx.meter.Remaining() + + return consumed, validateCtx.events, applyStateChanges(state, validateState) +} + +// execTx executes the tx messages on the provided state. If the tx fails then the state is discarded. +func (s STF[T]) execTx( + ctx context.Context, + state store.WriterMap, + gasLimit uint64, + tx T, + execMode corecontext.ExecMode, + hi header.Info, +) ([]transaction.Type, uint64, []event.Event, error) { + execState := s.branchFn(state) + + msgsResp, gasUsed, runTxMsgsEvents, txErr := s.runTxMsgs(ctx, execState, gasLimit, tx, execMode, hi) + if txErr != nil { + // in case of error during message execution, we do not apply the exec state. + // instead we run the post exec handler in a new branchFn from the initial state. + postTxState := s.branchFn(state) + postTxCtx := s.makeContext(ctx, appmanager.RuntimeIdentity, postTxState, execMode) + postTxCtx.setHeaderInfo(hi) + + // TODO: runtime sets a noop posttxexec if the app doesnt set anything (julien) + + postTxErr := s.postTxExec(postTxCtx, tx, false) + if postTxErr != nil { + // if the post tx handler fails, then we do not apply any state change to the initial state. + // we just return the exec gas used and a joined error from TX error and post TX error. + return nil, gasUsed, nil, errors.Join(txErr, postTxErr) + } + // in case post tx is successful, then we commit the post tx state to the initial state, + // and we return post tx events alongside exec gas used and the error of the tx. + applyErr := applyStateChanges(state, postTxState) + if applyErr != nil { + return nil, 0, nil, applyErr + } + return nil, gasUsed, postTxCtx.events, txErr + } + // tx execution went fine, now we use the same state to run the post tx exec handler, + // in case the execution of the post tx fails, then no state change is applied and the + // whole execution step is rolled back. + postTxCtx := s.makeContext(ctx, appmanager.RuntimeIdentity, execState, execMode) // NO gas limit. + postTxCtx.setHeaderInfo(hi) + postTxErr := s.postTxExec(postTxCtx, tx, true) + if postTxErr != nil { + // if post tx fails, then we do not apply any state change, we return the post tx error, + // alongside the gas used. + return nil, gasUsed, nil, postTxErr + } + // both the execution and post tx execution step were successful, so we apply the state changes + // to the provided state, and we return responses, and events from exec tx and post tx exec. + applyErr := applyStateChanges(state, execState) + if applyErr != nil { + return nil, 0, nil, applyErr + } + + return msgsResp, gasUsed, append(runTxMsgsEvents, postTxCtx.events...), nil +} + +// runTxMsgs will execute the messages contained in the TX with the provided state. +func (s STF[T]) runTxMsgs( + ctx context.Context, + state store.WriterMap, + gasLimit uint64, + tx T, + execMode corecontext.ExecMode, + hi header.Info, +) ([]transaction.Type, uint64, []event.Event, error) { + txSenders, err := tx.GetSenders() + if err != nil { + return nil, 0, nil, err + } + msgs, err := tx.GetMessages() + if err != nil { + return nil, 0, nil, err + } + msgResps := make([]transaction.Type, len(msgs)) + + execCtx := s.makeContext(ctx, nil, state, execMode) + execCtx.setHeaderInfo(hi) + execCtx.setGasLimit(gasLimit) + for i, msg := range msgs { + execCtx.sender = txSenders[i] + resp, err := s.handleMsg(execCtx, msg) + if err != nil { + return nil, 0, nil, fmt.Errorf("message execution at index %d failed: %w", i, err) + } + msgResps[i] = resp + } + + consumed := execCtx.meter.Limit() - execCtx.meter.Remaining() + return msgResps, consumed, execCtx.events, nil +} + +func (s STF[T]) preBlock( + ctx *executionContext, + txs []T, +) ([]event.Event, error) { + err := s.doPreBlock(ctx, txs) + if err != nil { + return nil, err + } + + for i, e := range ctx.events { + ctx.events[i].Attributes = append( + e.Attributes, + event.Attribute{Key: "mode", Value: "PreBlock"}, + ) + } + + return ctx.events, nil +} + +func (s STF[T]) runConsensusMessages( + ctx *executionContext, + messages []transaction.Type, +) ([]transaction.Type, error) { + responses := make([]transaction.Type, len(messages)) + for i := range messages { + resp, err := s.handleMsg(ctx, messages[i]) + if err != nil { + return nil, err + } + responses[i] = resp + } + + return responses, nil +} + +func (s STF[T]) beginBlock( + ctx *executionContext, +) (beginBlockEvents []event.Event, err error) { + err = s.doBeginBlock(ctx) + if err != nil { + return nil, err + } + + for i, e := range ctx.events { + ctx.events[i].Attributes = append( + e.Attributes, + event.Attribute{Key: "mode", Value: "BeginBlock"}, + ) + } + + return ctx.events, nil +} + +func (s STF[T]) endBlock( + ctx *executionContext, +) ([]event.Event, []appmodulev2.ValidatorUpdate, error) { + err := s.doEndBlock(ctx) + if err != nil { + return nil, nil, err + } + + events, valsetUpdates, err := s.validatorUpdates(ctx) + if err != nil { + return nil, nil, err + } + + ctx.events = append(ctx.events, events...) + + for i, e := range ctx.events { + ctx.events[i].Attributes = append( + e.Attributes, + event.Attribute{Key: "mode", Value: "BeginBlock"}, + ) + } + + return ctx.events, valsetUpdates, nil +} + +// validatorUpdates returns the validator updates for the current block. It is called by endBlock after the endblock execution has concluded +func (s STF[T]) validatorUpdates( + ctx *executionContext, +) ([]event.Event, []appmodulev2.ValidatorUpdate, error) { + valSetUpdates, err := s.doValidatorUpdate(ctx) + if err != nil { + return nil, nil, err + } + return ctx.events, valSetUpdates, nil +} + +const headerInfoPrefix = 0x0 + +// setHeaderInfo sets the header info in the state to be used by queries in the future. +func (s STF[T]) setHeaderInfo(state store.WriterMap, headerInfo header.Info) error { + runtimeStore, err := state.GetWriter(appmanager.RuntimeIdentity) + if err != nil { + return err + } + bz, err := headerInfo.Bytes() + if err != nil { + return err + } + err = runtimeStore.Set([]byte{headerInfoPrefix}, bz) + if err != nil { + return err + } + return nil +} + +// getHeaderInfo gets the header info from the state. It should only be used for queries +func (s STF[T]) getHeaderInfo(state store.WriterMap) (i header.Info, err error) { + runtimeStore, err := state.GetWriter(appmanager.RuntimeIdentity) + if err != nil { + return header.Info{}, err + } + v, err := runtimeStore.Get([]byte{headerInfoPrefix}) + if err != nil { + return header.Info{}, err + } + if v == nil { + return header.Info{}, nil + } + + err = i.FromBytes(v) + return i, err +} + +// Simulate simulates the execution of a tx on the provided state. +func (s STF[T]) Simulate( + ctx context.Context, + state store.ReaderMap, + gasLimit uint64, + tx T, +) (appmanager.TxResult, store.WriterMap) { + simulationState := s.branchFn(state) + hi, err := s.getHeaderInfo(simulationState) + if err != nil { + return appmanager.TxResult{}, nil + } + txr := s.deliverTx(ctx, simulationState, tx, corecontext.ExecModeSimulate, hi) + + return txr, simulationState +} + +// ValidateTx will run only the validation steps required for a transaction. +// Validations are run over the provided state, with the provided gas limit. +func (s STF[T]) ValidateTx( + ctx context.Context, + state store.ReaderMap, + gasLimit uint64, + tx T, +) appmanager.TxResult { + validationState := s.branchFn(state) + gasUsed, events, err := s.validateTx(ctx, validationState, gasLimit, tx) + return appmanager.TxResult{ + Events: events, + GasUsed: gasUsed, + Error: err, + } +} + +// Query executes the query on the provided state with the provided gas limits. +func (s STF[T]) Query( + ctx context.Context, + state store.ReaderMap, + gasLimit uint64, + req transaction.Type, +) (transaction.Type, error) { + queryState := s.branchFn(state) + hi, err := s.getHeaderInfo(queryState) + if err != nil { + return nil, err + } + queryCtx := s.makeContext(ctx, nil, queryState, corecontext.ExecModeSimulate) + queryCtx.setHeaderInfo(hi) + queryCtx.setGasLimit(gasLimit) + return s.handleQuery(queryCtx, req) +} + +func (s STF[T]) Message(ctx context.Context, msg transaction.Type) (response transaction.Type, err error) { + return s.handleMsg(ctx, msg) +} + +// RunWithCtx is made to support genesis, if genesis was just the execution of messages instead +// of being something custom then we would not need this. PLEASE DO NOT USE. +// TODO: Remove +func (s STF[T]) RunWithCtx( + ctx context.Context, + state store.ReaderMap, + closure func(ctx context.Context) error, +) (store.WriterMap, error) { + branchedState := s.branchFn(state) + // TODO do we need headerinfo for genesis? + stfCtx := s.makeContext(ctx, nil, branchedState, corecontext.ExecModeFinalize) + return branchedState, closure(stfCtx) +} + +// clone clones STF. +func (s STF[T]) clone() STF[T] { + return STF[T]{ + handleMsg: s.handleMsg, + handleQuery: s.handleQuery, + doPreBlock: s.doPreBlock, + doBeginBlock: s.doBeginBlock, + doEndBlock: s.doEndBlock, + doValidatorUpdate: s.doValidatorUpdate, + doTxValidation: s.doTxValidation, + postTxExec: s.postTxExec, + branchFn: s.branchFn, + makeGasMeter: s.makeGasMeter, + makeGasMeteredState: s.makeGasMeteredState, + } +} + +// executionContext is a struct that holds the context for the execution of a tx. +type executionContext struct { + context.Context + + // unmeteredState is storage without metering. Changes here are propagated to state which is the metered + // version. + unmeteredState store.WriterMap + // state is the gas metered state. + state store.WriterMap + // meter is the gas meter. + meter gas.Meter + // events are the current events. + events []event.Event + // sender is the causer of the state transition. + sender transaction.Identity + // headerInfo contains the block info. + headerInfo header.Info + // execMode retains information about the exec mode. + execMode corecontext.ExecMode + + branchFn branchFn + makeGasMeter makeGasMeterFn + makeGasMeteredStore makeGasMeteredStateFn +} + +// setHeaderInfo sets the header info in the state to be used by queries in the future. +func (e *executionContext) setHeaderInfo(hi header.Info) { + e.headerInfo = hi +} + +// setGasLimit will update the gas limit of the *executionContext +func (e *executionContext) setGasLimit(limit uint64) { + meter := e.makeGasMeter(limit) + meteredState := e.makeGasMeteredStore(meter, e.unmeteredState) + + e.meter = meter + e.state = meteredState +} + +// TODO: too many calls to makeContext can be expensive +// makeContext creates and returns a new execution context for the STF[T] type. +// It takes in the following parameters: +// - ctx: The context.Context object for the execution. +// - sender: The transaction.Identity object representing the sender of the transaction. +// - state: The store.WriterMap object for accessing and modifying the state. +// - gasLimit: The maximum amount of gas allowed for the execution. +// - execMode: The corecontext.ExecMode object representing the execution mode. +// +// It returns a pointer to the executionContext struct +func (s STF[T]) makeContext( + ctx context.Context, + sender transaction.Identity, + store store.WriterMap, + execMode corecontext.ExecMode, +) *executionContext { + return newExecutionContext( + s.makeGasMeter, + s.makeGasMeteredState, + s.branchFn, + ctx, + sender, + store, + execMode, + ) +} + +func newExecutionContext( + makeGasMeterFn makeGasMeterFn, + makeGasMeteredStoreFn makeGasMeteredStateFn, + branchFn branchFn, + ctx context.Context, + sender transaction.Identity, + state store.WriterMap, + execMode corecontext.ExecMode, +) *executionContext { + meter := makeGasMeterFn(gas.NoGasLimit) + meteredState := makeGasMeteredStoreFn(meter, state) + + return &executionContext{ + Context: ctx, + unmeteredState: state, + state: meteredState, + meter: meter, + events: make([]event.Event, 0), + sender: sender, + headerInfo: header.Info{}, + execMode: execMode, + branchFn: branchFn, + makeGasMeter: makeGasMeterFn, + makeGasMeteredStore: makeGasMeteredStoreFn, + } +} + +// applyStateChanges applies the state changes from the source store to the destination store. +// It retrieves the state changes from the source store using GetStateChanges method, +// and then applies those changes to the destination store using ApplyStateChanges method. +// If an error occurs during the retrieval or application of state changes, it is returned. +func applyStateChanges(dst, src store.WriterMap) error { + changes, err := src.GetStateChanges() + if err != nil { + return err + } + return dst.ApplyStateChanges(changes) +} + +// isCtxCancelled reports if the context was canceled. +func isCtxCancelled(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } +} diff --git a/server/v2/stf/stf_router.go b/server/v2/stf/stf_router.go new file mode 100644 index 000000000000..a8809d231299 --- /dev/null +++ b/server/v2/stf/stf_router.go @@ -0,0 +1,144 @@ +package stf + +import ( + "context" + "errors" + "fmt" + + gogoproto "github.com/cosmos/gogoproto/proto" + "google.golang.org/protobuf/proto" + + appmodulev2 "cosmossdk.io/core/appmodule/v2" +) + +var ErrNoHandler = errors.New("no handler") + +// NewMsgRouterBuilder is a router that routes messages to their respective handlers. +func NewMsgRouterBuilder() *MsgRouterBuilder { + return &MsgRouterBuilder{ + handlers: make(map[string]appmodulev2.Handler), + preHandlers: make(map[string][]appmodulev2.PreMsgHandler), + postHandlers: make(map[string][]appmodulev2.PostMsgHandler), + } +} + +type MsgRouterBuilder struct { + handlers map[string]appmodulev2.Handler + globalPreHandlers []appmodulev2.PreMsgHandler + preHandlers map[string][]appmodulev2.PreMsgHandler + postHandlers map[string][]appmodulev2.PostMsgHandler + globalPostHandlers []appmodulev2.PostMsgHandler +} + +func (b *MsgRouterBuilder) RegisterHandler(msgType string, handler appmodulev2.Handler) error { + // panic on override + if _, ok := b.handlers[msgType]; ok { + return fmt.Errorf("handler already registered: %s", msgType) + } + b.handlers[msgType] = handler + return nil +} + +func (b *MsgRouterBuilder) RegisterGlobalPreHandler(handler appmodulev2.PreMsgHandler) { + b.globalPreHandlers = append(b.globalPreHandlers, handler) +} + +func (b *MsgRouterBuilder) RegisterPreHandler(msgType string, handler appmodulev2.PreMsgHandler) { + b.preHandlers[msgType] = append(b.preHandlers[msgType], handler) +} + +func (b *MsgRouterBuilder) RegisterPostHandler(msgType string, handler appmodulev2.PostMsgHandler) { + b.postHandlers[msgType] = append(b.postHandlers[msgType], handler) +} + +func (b *MsgRouterBuilder) RegisterGlobalPostHandler(handler appmodulev2.PostMsgHandler) { + b.globalPostHandlers = append(b.globalPostHandlers, handler) +} + +func (b *MsgRouterBuilder) Build() (appmodulev2.Handler, error) { + handlers := make(map[string]appmodulev2.Handler) + + globalPreHandler := func(ctx context.Context, msg appmodulev2.Message) error { + for _, h := range b.globalPreHandlers { + err := h(ctx, msg) + if err != nil { + return err + } + } + return nil + } + + globalPostHandler := func(ctx context.Context, msg, msgResp appmodulev2.Message) error { + for _, h := range b.globalPostHandlers { + err := h(ctx, msg, msgResp) + if err != nil { + return err + } + } + return nil + } + + for msgType, handler := range b.handlers { + // find pre handler + preHandlers := b.preHandlers[msgType] + // find post handler + postHandlers := b.postHandlers[msgType] + // build the handler + handlers[msgType] = buildHandler(handler, preHandlers, globalPreHandler, postHandlers, globalPostHandler) + } + + // return handler as function + return func(ctx context.Context, msg appmodulev2.Message) (appmodulev2.Message, error) { + typeName := msgTypeURL(msg) + handler, exists := handlers[typeName] + if !exists { + return nil, fmt.Errorf("%w: %s", ErrNoHandler, typeName) + } + return handler(ctx, msg) + }, nil +} + +func buildHandler( + handler appmodulev2.Handler, + preHandlers []appmodulev2.PreMsgHandler, + globalPreHandler appmodulev2.PreMsgHandler, + postHandlers []appmodulev2.PostMsgHandler, + globalPostHandler appmodulev2.PostMsgHandler, +) appmodulev2.Handler { + return func(ctx context.Context, msg appmodulev2.Message) (msgResp appmodulev2.Message, err error) { + if len(preHandlers) != 0 { + for _, preHandler := range preHandlers { + if err := preHandler(ctx, msg); err != nil { + return nil, err + } + } + } + err = globalPreHandler(ctx, msg) + if err != nil { + return nil, err + } + msgResp, err = handler(ctx, msg) + if err != nil { + return nil, err + } + + if len(postHandlers) != 0 { + for _, postHandler := range postHandlers { + if err := postHandler(ctx, msg, msgResp); err != nil { + return nil, err + } + } + } + err = globalPostHandler(ctx, msg, msgResp) + return msgResp, err + } +} + +// msgTypeURL returns the TypeURL of a proto message. +func msgTypeURL(msg gogoproto.Message) string { + if m, ok := msg.(proto.Message); ok { + return string(m.ProtoReflect().Descriptor().FullName()) + } + + return gogoproto.MessageName(msg) +} diff --git a/server/v2/stf/stf_test.go b/server/v2/stf/stf_test.go new file mode 100644 index 000000000000..66a833c0e247 --- /dev/null +++ b/server/v2/stf/stf_test.go @@ -0,0 +1,235 @@ +package stf + +import ( + "context" + "crypto/sha256" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/wrapperspb" + + appmanager "cosmossdk.io/core/app" + appmodulev2 "cosmossdk.io/core/appmodule/v2" + coregas "cosmossdk.io/core/gas" + "cosmossdk.io/core/store" + "cosmossdk.io/core/transaction" + "cosmossdk.io/server/v2/stf/branch" + "cosmossdk.io/server/v2/stf/gas" + "cosmossdk.io/server/v2/stf/mock" +) + +func TestSTF(t *testing.T) { + state := mock.DB() + mockTx := mock.Tx{ + Sender: []byte("sender"), + Msg: wrapperspb.Bool(true), // msg does not matter at all because our handler does nothing. + GasLimit: 100_000, + } + + sum := sha256.Sum256([]byte("test-hash")) + + s := &STF[mock.Tx]{ + handleMsg: func(ctx context.Context, msg transaction.Type) (msgResp transaction.Type, err error) { + kvSet(t, ctx, "exec") + return nil, nil + }, + handleQuery: nil, + doPreBlock: func(ctx context.Context, txs []mock.Tx) error { return nil }, + doBeginBlock: func(ctx context.Context) error { + kvSet(t, ctx, "begin-block") + return nil + }, + doEndBlock: func(ctx context.Context) error { + kvSet(t, ctx, "end-block") + return nil + }, + doValidatorUpdate: func(ctx context.Context) ([]appmodulev2.ValidatorUpdate, error) { return nil, nil }, + doTxValidation: func(ctx context.Context, tx mock.Tx) error { + kvSet(t, ctx, "validate") + return nil + }, + postTxExec: func(ctx context.Context, tx mock.Tx, success bool) error { + kvSet(t, ctx, "post-tx-exec") + return nil + }, + branchFn: branch.DefaultNewWriterMap, + makeGasMeter: gas.DefaultGasMeter, + makeGasMeteredState: gas.DefaultWrapWithGasMeter, + } + + t.Run("begin and end block", func(t *testing.T) { + _, newState, err := s.DeliverBlock(context.Background(), &appmanager.BlockRequest[mock.Tx]{ + Height: uint64(1), + Time: time.Date(2024, 2, 3, 18, 23, 0, 0, time.UTC), + AppHash: sum[:], + Hash: sum[:], + }, state) + require.NoError(t, err) + stateHas(t, newState, "begin-block") + stateHas(t, newState, "end-block") + }) + + t.Run("basic tx", func(t *testing.T) { + result, newState, err := s.DeliverBlock(context.Background(), &appmanager.BlockRequest[mock.Tx]{ + Height: uint64(1), + Time: time.Date(2024, 2, 3, 18, 23, 0, 0, time.UTC), + AppHash: sum[:], + Hash: sum[:], + Txs: []mock.Tx{mockTx}, + }, state) + require.NoError(t, err) + stateHas(t, newState, "validate") + stateHas(t, newState, "exec") + stateHas(t, newState, "post-tx-exec") + + require.Len(t, result.TxResults, 1) + txResult := result.TxResults[0] + require.NotZero(t, txResult.GasUsed) + require.Equal(t, mockTx.GasLimit, txResult.GasWanted) + }) + + t.Run("exec tx out of gas", func(t *testing.T) { + s := s.clone() + + mockTx := mock.Tx{ + Sender: []byte("sender"), + Msg: wrapperspb.Bool(true), // msg does not matter at all because our handler does nothing. + GasLimit: 0, // NO GAS! + } + + // this handler will propagate the storage error back, we expect + // out of gas immediately at tx validation level. + s.doTxValidation = func(ctx context.Context, tx mock.Tx) error { + w, err := ctx.(*executionContext).state.GetWriter(actorName) + require.NoError(t, err) + err = w.Set([]byte("gas_failure"), []byte{}) + require.Error(t, err) + return err + } + + result, newState, err := s.DeliverBlock(context.Background(), &appmanager.BlockRequest[mock.Tx]{ + Height: uint64(1), + Time: time.Date(2024, 2, 3, 18, 23, 0, 0, time.UTC), + AppHash: sum[:], + Hash: sum[:], + Txs: []mock.Tx{mockTx}, + }, state) + require.NoError(t, err) + stateNotHas(t, newState, "gas_failure") // assert during out of gas no state changes leaked. + require.ErrorIs(t, result.TxResults[0].Error, coregas.ErrOutOfGas, result.TxResults[0].Error) + }) + + t.Run("fail exec tx", func(t *testing.T) { + // update the stf to fail on the handler + s := s.clone() + s.handleMsg = func(ctx context.Context, msg transaction.Type) (msgResp transaction.Type, err error) { + return nil, fmt.Errorf("failure") + } + + blockResult, newState, err := s.DeliverBlock(context.Background(), &appmanager.BlockRequest[mock.Tx]{ + Height: uint64(1), + Time: time.Date(2024, 2, 3, 18, 23, 0, 0, time.UTC), + AppHash: sum[:], + Hash: sum[:], + Txs: []mock.Tx{mockTx}, + }, state) + require.NoError(t, err) + require.ErrorContains(t, blockResult.TxResults[0].Error, "failure") + stateHas(t, newState, "begin-block") + stateHas(t, newState, "end-block") + stateHas(t, newState, "validate") + stateNotHas(t, newState, "exec") + stateHas(t, newState, "post-tx-exec") + }) + + t.Run("tx is success but post tx failed", func(t *testing.T) { + s := s.clone() + s.postTxExec = func(ctx context.Context, tx mock.Tx, success bool) error { + return fmt.Errorf("post tx failure") + } + blockResult, newState, err := s.DeliverBlock(context.Background(), &appmanager.BlockRequest[mock.Tx]{ + Height: uint64(1), + Time: time.Date(2024, 2, 3, 18, 23, 0, 0, time.UTC), + AppHash: sum[:], + Hash: sum[:], + Txs: []mock.Tx{mockTx}, + }, state) + require.NoError(t, err) + require.ErrorContains(t, blockResult.TxResults[0].Error, "post tx failure") + stateHas(t, newState, "begin-block") + stateHas(t, newState, "end-block") + stateHas(t, newState, "validate") + stateNotHas(t, newState, "exec") + stateNotHas(t, newState, "post-tx-exec") + }) + + t.Run("tx failed and post tx failed", func(t *testing.T) { + s := s.clone() + s.handleMsg = func(ctx context.Context, msg transaction.Type) (msgResp transaction.Type, err error) { + return nil, fmt.Errorf("exec failure") + } + s.postTxExec = func(ctx context.Context, tx mock.Tx, success bool) error { return fmt.Errorf("post tx failure") } + blockResult, newState, err := s.DeliverBlock(context.Background(), &appmanager.BlockRequest[mock.Tx]{ + Height: uint64(1), + Time: time.Date(2024, 2, 3, 18, 23, 0, 0, time.UTC), + AppHash: sum[:], + Hash: sum[:], + Txs: []mock.Tx{mockTx}, + }, state) + require.NoError(t, err) + require.ErrorContains(t, blockResult.TxResults[0].Error, "exec failure\npost tx failure") + stateHas(t, newState, "begin-block") + stateHas(t, newState, "end-block") + stateHas(t, newState, "validate") + stateNotHas(t, newState, "exec") + stateNotHas(t, newState, "post-tx-exec") + }) + + t.Run("fail validate tx", func(t *testing.T) { + // update stf to fail on the validation step + s := s.clone() + s.doTxValidation = func(ctx context.Context, tx mock.Tx) error { return fmt.Errorf("failure") } + blockResult, newState, err := s.DeliverBlock(context.Background(), &appmanager.BlockRequest[mock.Tx]{ + Height: uint64(1), + Time: time.Date(2024, 2, 3, 18, 23, 0, 0, time.UTC), + AppHash: sum[:], + Hash: sum[:], + Txs: []mock.Tx{mockTx}, + }, state) + require.NoError(t, err) + require.ErrorContains(t, blockResult.TxResults[0].Error, "failure") + stateHas(t, newState, "begin-block") + stateHas(t, newState, "end-block") + stateNotHas(t, newState, "validate") + stateNotHas(t, newState, "exec") + }) +} + +var actorName = []byte("cookies") + +func kvSet(t *testing.T, ctx context.Context, v string) { + t.Helper() + state, err := ctx.(*executionContext).state.GetWriter(actorName) + require.NoError(t, err) + require.NoError(t, state.Set([]byte(v), []byte(v))) +} + +func stateHas(t *testing.T, accountState store.ReaderMap, key string) { + t.Helper() + state, err := accountState.GetReader(actorName) + require.NoError(t, err) + has, err := state.Has([]byte(key)) + require.NoError(t, err) + require.Truef(t, has, "state did not have key: %s", key) +} + +func stateNotHas(t *testing.T, accountState store.ReaderMap, key string) { + t.Helper() + state, err := accountState.GetReader(actorName) + require.NoError(t, err) + has, err := state.Has([]byte(key)) + require.NoError(t, err) + require.Falsef(t, has, "state was not supposed to have key: %s", key) +}