diff --git a/dot/rpc/subscription/listeners_test.go b/dot/rpc/subscription/listeners_test.go index fb1c10788e..ae991a2aa7 100644 --- a/dot/rpc/subscription/listeners_test.go +++ b/dot/rpc/subscription/listeners_test.go @@ -364,7 +364,7 @@ func TestRuntimeChannelListener_Listen(t *testing.T) { expectedInitialResponse.Params.Result = expectedInitialVersion instance := wasmer.NewTestInstance(t, runtime.NODE_RUNTIME) - _, err := runtime.GetRuntimeBlob(runtime.POLKADOT_RUNTIME_FP, runtime.POLKADOT_RUNTIME_URL) + err := runtime.GetRuntimeBlob(runtime.POLKADOT_RUNTIME_FP, runtime.POLKADOT_RUNTIME_URL) require.NoError(t, err) fp, err := filepath.Abs(runtime.POLKADOT_RUNTIME_FP) require.NoError(t, err) diff --git a/lib/runtime/life/test_helpers.go b/lib/runtime/life/test_helpers.go index 853bd4f530..0f58c8a249 100644 --- a/lib/runtime/life/test_helpers.go +++ b/lib/runtime/life/test_helpers.go @@ -48,7 +48,7 @@ func NewTestInstanceWithTrie(t *testing.T, targetRuntime string, tt *trie.Trie, func setupConfig(t *testing.T, targetRuntime string, tt *trie.Trie, lvl log.Lvl, role byte) (string, *Config) { testRuntimeFilePath, testRuntimeURL := runtime.GetRuntimeVars(targetRuntime) - _, err := runtime.GetRuntimeBlob(testRuntimeFilePath, testRuntimeURL) + err := runtime.GetRuntimeBlob(testRuntimeFilePath, testRuntimeURL) require.Nil(t, err, "Fail: could not get runtime", "targetRuntime", targetRuntime) s, err := storage.NewTrieState(tt) diff --git a/lib/runtime/test_helpers.go b/lib/runtime/test_helpers.go index 59da6fd68d..c513aa4bdf 100644 --- a/lib/runtime/test_helpers.go +++ b/lib/runtime/test_helpers.go @@ -17,12 +17,13 @@ package runtime import ( - "io" + "context" "io/ioutil" "net/http" "os" "path" "testing" + "time" "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/lib/common" @@ -80,42 +81,41 @@ func GetAbsolutePath(targetDir string) string { } // GetRuntimeBlob checks if the test wasm @testRuntimeFilePath exists and if not, it fetches it from @testRuntimeURL -func GetRuntimeBlob(testRuntimeFilePath, testRuntimeURL string) (n int64, err error) { +func GetRuntimeBlob(testRuntimeFilePath, testRuntimeURL string) error { if utils.PathExists(testRuntimeFilePath) { - return 0, nil + return nil } - out, err := os.Create(testRuntimeFilePath) - if err != nil { - return 0, err - } + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() - /* #nosec */ - resp, err := http.Get(testRuntimeURL) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, testRuntimeURL, nil) if err != nil { - return 0, err + return err } - defer func() { - _ = resp.Body.Close() - }() - n, err = io.Copy(out, resp.Body) + const runtimeReqTimout = time.Second * 30 + + httpcli := http.Client{Timeout: runtimeReqTimout} + resp, err := httpcli.Do(req) if err != nil { - return 0, err + return err } - if err = out.Close(); err != nil { - return 0, err + respBody, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err } + defer resp.Body.Close() //nolint:errcheck - return n, nil + return ioutil.WriteFile(testRuntimeFilePath, respBody, os.ModePerm) } // TestRuntimeNetwork ... type TestRuntimeNetwork struct{} // NetworkState ... -func (trn *TestRuntimeNetwork) NetworkState() common.NetworkState { +func (*TestRuntimeNetwork) NetworkState() common.NetworkState { testAddrs := []ma.Multiaddr(nil) // create mock multiaddress @@ -155,11 +155,12 @@ func GenerateRuntimeWasmFile() ([]string, error) { var wasmFilePaths []string for _, rt := range runtimes { testRuntimeFilePath, testRuntimeURL := GetRuntimeVars(rt) - wasmFilePaths = append(wasmFilePaths, testRuntimeFilePath) - _, err := GetRuntimeBlob(testRuntimeFilePath, testRuntimeURL) + err := GetRuntimeBlob(testRuntimeFilePath, testRuntimeURL) if err != nil { return nil, err } + + wasmFilePaths = append(wasmFilePaths, testRuntimeFilePath) } return wasmFilePaths, nil } diff --git a/lib/runtime/wasmer/imports.go b/lib/runtime/wasmer/imports.go index f75a1ee4c7..39e3944238 100644 --- a/lib/runtime/wasmer/imports.go +++ b/lib/runtime/wasmer/imports.go @@ -895,10 +895,37 @@ func ext_trie_blake2_256_ordered_root_version_1(context unsafe.Pointer, dataSpan } //export ext_trie_blake2_256_verify_proof_version_1 -func ext_trie_blake2_256_verify_proof_version_1(context unsafe.Pointer, a C.int32_t, b, c, d C.int64_t) C.int32_t { // skipcq: RVV-B0012 +func ext_trie_blake2_256_verify_proof_version_1(context unsafe.Pointer, rootSpan C.int32_t, proofSpan, keySpan, valueSpan C.int64_t) C.int32_t { logger.Debug("[ext_trie_blake2_256_verify_proof_version_1] executing...") - logger.Warn("[ext_trie_blake2_256_verify_proof_version_1] unimplemented") - return 0 + + instanceContext := wasm.IntoInstanceContext(context) + + toDecProofs := asMemorySlice(instanceContext, proofSpan) + var decProofs [][]byte + err := scale.Unmarshal(toDecProofs, &decProofs) + if err != nil { + logger.Error("[ext_trie_blake2_256_verify_proof_version_1]", "error", err) + return C.int32_t(0) + } + + key := asMemorySlice(instanceContext, keySpan) + value := asMemorySlice(instanceContext, valueSpan) + + mem := instanceContext.Memory().Data() + trieRoot := mem[rootSpan : rootSpan+32] + + exists, err := trie.VerifyProof(decProofs, trieRoot, []trie.Pair{{Key: key, Value: value}}) + if err != nil { + logger.Error("[ext_trie_blake2_256_verify_proof_version_1]", "error", err) + return C.int32_t(0) + } + + var result C.int32_t = 0 + if exists { + result = 1 + } + + return result } //export ext_misc_print_hex_version_1 @@ -2131,7 +2158,7 @@ func toKillStorageResultEnum(allRemoved bool, numRemoved uint32) ([]byte, error) // Wraps slice in optional.FixedSizeBytes and copies result to wasm memory. Returns resulting 64bit span descriptor func toWasmMemoryFixedSizeOptional(context wasm.InstanceContext, data []byte) (int64, error) { var opt [64]byte - copy(opt[:], data[:]) + copy(opt[:], data) enc, err := scale.Marshal(&opt) if err != nil { return 0, err diff --git a/lib/runtime/wasmer/imports_test.go b/lib/runtime/wasmer/imports_test.go index e921dd503c..dd7ab7affa 100644 --- a/lib/runtime/wasmer/imports_test.go +++ b/lib/runtime/wasmer/imports_test.go @@ -19,10 +19,12 @@ package wasmer import ( "bytes" "encoding/binary" + "io/ioutil" "os" "sort" "testing" + "github.com/ChainSafe/chaindb" log "github.com/ChainSafe/log15" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -1668,3 +1670,99 @@ func Test_ext_trie_blake2_256_root_version_1(t *testing.T) { expected := tt.MustHash() require.Equal(t, expected[:], hash) } + +func Test_ext_trie_blake2_256_verify_proof_version_1(t *testing.T) { + t.Parallel() + + tmp, err := ioutil.TempDir("", "*-test-trie") + require.NoError(t, err) + + defer os.RemoveAll(tmp) + + memdb, err := chaindb.NewBadgerDB(&chaindb.Config{ + InMemory: true, + DataDir: tmp, + }) + require.NoError(t, err) + + otherTrie := trie.NewEmptyTrie() + otherTrie.Put([]byte("simple"), []byte("cat")) + + otherHash, err := otherTrie.Hash() + require.NoError(t, err) + + tr := trie.NewEmptyTrie() + tr.Put([]byte("do"), []byte("verb")) + tr.Put([]byte("domain"), []byte("website")) + tr.Put([]byte("other"), []byte("random")) + tr.Put([]byte("otherwise"), []byte("randomstuff")) + tr.Put([]byte("cat"), []byte("another animal")) + + err = tr.Store(memdb) + require.NoError(t, err) + + hash, err := tr.Hash() + require.NoError(t, err) + + keys := [][]byte{ + []byte("do"), + []byte("domain"), + []byte("other"), + []byte("otherwise"), + []byte("cat"), + } + + root := hash.ToBytes() + otherRoot := otherHash.ToBytes() + + proof, err := trie.GenerateProof(root, keys, memdb) + require.NoError(t, err) + + testcases := map[string]struct { + root, key, value []byte + proof [][]byte + expect bool + }{ + "Proof should be true": {root: root, key: []byte("do"), proof: proof, value: []byte("verb"), expect: true}, + "Root empty, proof should be false": {root: []byte{}, key: []byte("do"), proof: proof, value: []byte("verb"), expect: false}, + "Other root, proof should be false": {root: otherRoot, key: []byte("do"), proof: proof, value: []byte("verb"), expect: false}, + "Value empty, proof should be true": {root: root, key: []byte("do"), proof: proof, value: nil, expect: true}, + "Unknow key, proof should be false": {root: root, key: []byte("unknow"), proof: proof, value: nil, expect: false}, + "Key and value unknow, proof should be false": {root: root, key: []byte("unknow"), proof: proof, value: []byte("unknow"), expect: false}, + "Empty proof, should be false": {root: root, key: []byte("do"), proof: [][]byte{}, value: nil, expect: false}, + } + + inst := NewTestInstance(t, runtime.HOST_API_TEST_RUNTIME) + + for name, testcase := range testcases { + t.Run(name, func(t *testing.T) { + t.Parallel() + + hashEnc, err := scale.Marshal(testcase.root) + require.NoError(t, err) + + args := []byte{} + args = append(args, hashEnc...) + + encProof, err := scale.Marshal(testcase.proof) + require.NoError(t, err) + args = append(args, encProof...) + + keyEnc, err := scale.Marshal(testcase.key) + require.NoError(t, err) + args = append(args, keyEnc...) + + valueEnc, err := scale.Marshal(testcase.value) + require.NoError(t, err) + args = append(args, valueEnc...) + + res, err := inst.Exec("rtm_ext_trie_blake2_256_verify_proof_version_1", args) + require.NoError(t, err) + + var got bool + err = scale.Unmarshal(res, &got) + require.NoError(t, err) + require.Equal(t, testcase.expect, got) + }) + } +} diff --git a/lib/runtime/wasmer/instance_test.go b/lib/runtime/wasmer/instance_test.go index e7cf4c86ed..ebff7a099e 100644 --- a/lib/runtime/wasmer/instance_test.go +++ b/lib/runtime/wasmer/instance_test.go @@ -49,7 +49,7 @@ func TestPointerSize(t *testing.T) { func TestInstance_CheckRuntimeVersion(t *testing.T) { instance := NewTestInstance(t, runtime.NODE_RUNTIME) - _, err := runtime.GetRuntimeBlob(runtime.POLKADOT_RUNTIME_FP, runtime.POLKADOT_RUNTIME_URL) + err := runtime.GetRuntimeBlob(runtime.POLKADOT_RUNTIME_FP, runtime.POLKADOT_RUNTIME_URL) require.NoError(t, err) fp, err := filepath.Abs(runtime.POLKADOT_RUNTIME_FP) require.NoError(t, err) diff --git a/lib/runtime/wasmer/test_helpers.go b/lib/runtime/wasmer/test_helpers.go index 81e89151ea..c62f02a36c 100644 --- a/lib/runtime/wasmer/test_helpers.go +++ b/lib/runtime/wasmer/test_helpers.go @@ -60,7 +60,7 @@ func NewTestInstanceWithRole(t *testing.T, targetRuntime string, role byte) *Ins func setupConfig(t *testing.T, targetRuntime string, tt *trie.Trie, lvl log.Lvl, role byte) (string, *Config) { testRuntimeFilePath, testRuntimeURL := runtime.GetRuntimeVars(targetRuntime) - _, err := runtime.GetRuntimeBlob(testRuntimeFilePath, testRuntimeURL) + err := runtime.GetRuntimeBlob(testRuntimeFilePath, testRuntimeURL) require.Nil(t, err, "Fail: could not get runtime", "targetRuntime", targetRuntime) s, err := storage.NewTrieState(tt) diff --git a/lib/trie/proof.go b/lib/trie/proof.go index a4a83b919f..b63538ab48 100644 --- a/lib/trie/proof.go +++ b/lib/trie/proof.go @@ -27,21 +27,23 @@ import ( ) var ( - // ErrEmptyTrieRoot occurs when trying to craft a prove with an empty trie root + // ErrEmptyTrieRoot ... ErrEmptyTrieRoot = errors.New("provided trie must have a root") - // ErrValueNotFound indicates that a returned verify proof value doesnt match with the expected value on items array + // ErrValueNotFound ... ErrValueNotFound = errors.New("expected value not found in the trie") - // ErrDuplicateKeys not allowed to verify proof with duplicate keys + // ErrKeyNotFound ... + ErrKeyNotFound = errors.New("expected key not found in the trie") + + // ErrDuplicateKeys ... ErrDuplicateKeys = errors.New("duplicate keys on verify proof") - // ErrLoadFromProof occurs when there are problems with the proof slice while building the partial proof trie + // ErrLoadFromProof ... ErrLoadFromProof = errors.New("failed to build the proof trie") ) // GenerateProof receive the keys to proof, the trie root and a reference to database -// will func GenerateProof(root []byte, keys [][]byte, db chaindb.Database) ([][]byte, error) { trackedProofs := make(map[string][]byte) @@ -100,9 +102,11 @@ func VerifyProof(proof [][]byte, root []byte, items []Pair) (bool, error) { for _, item := range items { recValue := proofTrie.Get(item.Key) - + if recValue == nil { + return false, ErrKeyNotFound + } // here we need to compare value only if the caller pass the value - if item.Value != nil && !bytes.Equal(item.Value, recValue) { + if len(item.Value) > 0 && !bytes.Equal(item.Value, recValue) { return false, ErrValueNotFound } }