Skip to content

Commit

Permalink
Merge pull request ethereum#7 from CortexFoundation/feat-infer
Browse files Browse the repository at this point in the history
add flags for feature: opInfer and unit test for opInfer
  • Loading branch information
SiNZeRo authored Jun 5, 2018
2 parents 7d4c070 + 928e360 commit 58254cb
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 46 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,9 @@ profile.cov
/dashboard/assets/package-lock.json

**/yarn-error.log

#python_restful
python_restful/input_data
python_restful/model
python_restful/upload
python_restful/__pycache__/
1 change: 1 addition & 0 deletions cmd/geth/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ var (
utils.GpoPercentileFlag,
utils.ExtraDataFlag,
configFileFlag,
utils.ModelCallInterfaceFlag,
}

rpcFlags = []cli.Flag{
Expand Down
10 changes: 9 additions & 1 deletion cmd/utils/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,11 @@ var (
Usage: "Minimum POW accepted",
Value: whisper.DefaultMinimumPoW,
}
ModelCallInterfaceFlag = cli.StringFlag{
Name: "cvm.inferuri",
Usage: "infer uri",
Value: "http://127.0.0.1:5000/infer",
}
)

// MakeDataDir retrieves the currently requested data directory, terminating
Expand Down Expand Up @@ -1247,7 +1252,10 @@ func MakeChain(ctx *cli.Context, stack *node.Node) (chain *core.BlockChain, chai
if ctx.GlobalIsSet(CacheFlag.Name) || ctx.GlobalIsSet(CacheGCFlag.Name) {
cache.TrieNodeLimit = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheGCFlag.Name) / 100
}
vmcfg := vm.Config{EnablePreimageRecording: ctx.GlobalBool(VMEnableDebugFlag.Name)}
vmcfg := vm.Config{
EnablePreimageRecording: ctx.GlobalBool(VMEnableDebugFlag.Name),
InferURI: ctx.GlobalString(ModelCallInterfaceFlag.Name),
}
chain, err = core.NewBlockChain(chainDb, cache, config, engine, vmcfg)
if err != nil {
Fatalf("Can't create BlockChain: %v", err)
Expand Down
30 changes: 4 additions & 26 deletions core/vm/evm.go
Original file line number Diff line number Diff line change
Expand Up @@ -422,45 +422,23 @@ func (evm *EVM) ChainConfig() *params.ChainConfig { return evm.chainConfig }
// Interpreter returns the EVM interpreter
func (evm *EVM) Interpreter() *Interpreter { return evm.interpreter }

func (evm *EVM) CallExternal(call_type string, input [][]byte) []byte {
if call_type == "infer" {
model_meta_hash := input[0]
input_meta_hash := input[1]
requestBody := fmt.Sprintf(`{"model_addr":"%x", "input_addr":"%x"}`, model_meta_hash, input_meta_hash)
fmt.Println(requestBody)
resp, err := resty.R().
SetHeader("Content-Type", "application/json").
SetBody(requestBody).
Post("http://127.0.0.1:5000/infer")
if err != nil {
return []byte("ERROR0")
}
fmt.Println(resp.String())
js, _ := simplejson.NewJson([]byte(resp.String()))
int_output_tmp, out_err := js.Get("info").String()
int_output, err := strconv.Atoi(int_output_tmp)
fmt.Println("out: ", int_output, "err:", out_err, " resp", resp.String(), "js", js)
return BinaryWrite(int64(int_output))
}
return []byte{0}
}

// infer function that returns an int64 as output, can be used a categorical output
func (evm *EVM) Infer(model_meta_hash []byte, input_meta_hash []byte) ([]byte, error) {
requestBody := fmt.Sprintf(`{"model_addr":"%x", "input_addr":"%x"}`, model_meta_hash, input_meta_hash)
fmt.Println(requestBody)
resp, err := resty.R().
SetHeader("Content-Type", "application/json").
SetBody(requestBody).
Post("http://127.0.0.1:5000/infer")
Post(evm.interpreter.cfg.InferURI)
if err != nil {
return []byte{}, errors.New("evm.Infer: External Call Error")
}
fmt.Println(resp.String())
js, _ := simplejson.NewJson([]byte(resp.String()))
int_output_tmp, out_err := js.Get("info").String()
if out_err != nil {
return []byte{}, errors.New("evm.Infer: External Call Error")
}
int_output, err := strconv.Atoi(int_output_tmp)
fmt.Println("out: ", int_output, "err:", out_err, " resp", resp.String(), "js", js)
buf := new(bytes.Buffer)
if err := binary.Write(buf, binary.BigEndian, int64(int_output)); err != nil {
return []byte{}, errors.New("evm.Infer: Type Conversion Error")
Expand Down
14 changes: 4 additions & 10 deletions core/vm/instructions.go
Original file line number Diff line number Diff line change
Expand Up @@ -649,32 +649,26 @@ func opInfer(pc *uint64, evm *EVM, contract *Contract, memory *Memory, stack *St

_modelMeta := evm.StateDB.GetCode(modelAddr)
_inputMeta := evm.StateDB.GetCode(inputAddr)

// fmt.Println("_model: ", _modelMeta)
// fmt.Println("_input: ", _inputMeta)
var (
modelMeta *types.ModelMeta
inputMeta *types.InputMeta
)
var err error
if modelMeta, err = types.ParseModelMeta(_modelMeta); err != nil {
stack.push(evm.interpreter.intPool.get().SetUint64(1))
stack.push(evm.interpreter.intPool.getZero())
return nil, err
}
if inputMeta, err = types.ParseInputMeta(_inputMeta); err != nil {
stack.push(evm.interpreter.intPool.get().SetUint64(1))
stack.push(evm.interpreter.intPool.getZero())
return nil, err
}

output, err := evm.Infer(modelMeta.Hash.Bytes(), inputMeta.Hash.Bytes())
if err != nil {
stack.push(evm.interpreter.intPool.getZero())
return nil, err
} else {
stack.push(evm.interpreter.intPool.get().SetUint64(1))
}

fmt.Println("model, input", modelMeta.Hash, inputMeta.Hash)
fmt.Println("model, input", string(modelMeta.Hash.Bytes()), string(inputMeta.Hash.Bytes()))
output, err := evm.Infer(modelMeta.Hash.Bytes(), inputMeta.Hash.Bytes())
memory.Set(offset.Uint64(), size.Uint64(), output)
_, _ = inputMeta, modelMeta
return nil, nil
Expand Down
11 changes: 2 additions & 9 deletions core/vm/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ type Config struct {
// may be left uninitialised and will be set to the default
// table.
JumpTable [256]operation
// uri for remote infer service
InferURI string
}

// Interpreter is used to run Ethereum based contracts and will utilise the
Expand All @@ -61,27 +63,18 @@ func NewInterpreter(evm *EVM, cfg Config) *Interpreter {
// We use the STOP instruction whether to see
// the jump table was initialised. If it was not
// we'll set the default jump table.
fmt.Println("evm.BlockNumber", evm.BlockNumber)
fmt.Println(evm.ChainConfig().IsConstantinople(evm.BlockNumber))
fmt.Println(evm.ChainConfig().IsByzantium(evm.BlockNumber))
fmt.Println(evm.ChainConfig().IsHomestead(evm.BlockNumber))
if !cfg.JumpTable[STOP].valid {
switch {
case evm.ChainConfig().IsConstantinople(evm.BlockNumber):
cfg.JumpTable = constantinopleInstructionSet
fmt.Println("constantinopleInstructionSet")
case evm.ChainConfig().IsByzantium(evm.BlockNumber):
cfg.JumpTable = byzantiumInstructionSet
fmt.Println("byzantiumInstructionSet")
case evm.ChainConfig().IsHomestead(evm.BlockNumber):
cfg.JumpTable = homesteadInstructionSet
fmt.Println("homesteadInstructionSet")
default:
cfg.JumpTable = frontierInstructionSet
fmt.Println("frontierInstructionSet")
}
}

return &Interpreter{
evm: evm,
cfg: cfg,
Expand Down
114 changes: 114 additions & 0 deletions core/vm/runtime/runner_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package runtime

import (
"encoding/hex"
"fmt"
"math/big"
"os"
"testing"

"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/core/state"
"github.com/ethereum/go-ethereum/core/types"
"github.com/ethereum/go-ethereum/core/vm"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/ethdb"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp"
)

func TestRunCmd(t *testing.T) {
glogger := log.NewGlogHandler(log.StreamHandler(os.Stderr, log.TerminalFormat(false)))
log.Root().SetHandler(glogger)

var (
tracer vm.Tracer
debugLogger *vm.StructLogger
statedb *state.StateDB
sender = common.BytesToAddress([]byte("sender"))
receiver = common.BytesToAddress([]byte("receiver"))
blockNumber uint64
)
logconfig := &vm.LogConfig{
Debug: true,
}
debugLogger = vm.NewStructLogger(logconfig)
tracer = debugLogger
{
statedb, _ = state.New(common.Hash{}, state.NewDatabase(ethdb.NewMemDatabase()))
}
statedb.CreateAccount(sender)
mh, _ := hex.DecodeString("5c4d1f84063be8e25e83da6452b1821926548b3c2a2a903a0724e14d5c917b00")
ih, _ := hex.DecodeString("c0a1f3c82e11e314822679e4834e3bc575bd017d12d888acda4a851a62d261dc")
testModelMeta, _ := rlp.EncodeToBytes(
&types.ModelMeta{
Hash: common.BytesToHash(mh),
RawSize: 10000,
InputShape: []uint64{10, 1},
OutputShape: []uint64{1},
Gas: 100000,
AuthorAddress: common.BytesToAddress(crypto.Keccak256([]byte{0x2, 0x2})),
})
// new a modelmeta at 0x1001 and new a datameta at 0x2001

testInputMeta, _ := rlp.EncodeToBytes(
&types.InputMeta{
Hash: common.BytesToHash(ih),
RawSize: 10000,
Shape: []uint64{1},
AuthorAddress: common.BytesToAddress(crypto.Keccak256([]byte{0x3})),
})
statedb.SetCode(common.HexToAddress("0x1001"), append([]byte{0x0, 0x1}, []byte(testModelMeta)...))
statedb.SetCode(common.HexToAddress("0x2001"), append([]byte{0x0, 0x2}, []byte(testInputMeta)...))

var (
code []byte
ret []byte
err error
)

code = common.Hex2Bytes("60086000612001611001c0")
input_flag := ""

initialGas := uint64(10000000)
runtimeConfig := Config{
Origin: sender,
State: statedb,
GasLimit: initialGas,
GasPrice: new(big.Int),
Value: new(big.Int),
BlockNumber: new(big.Int).SetUint64(blockNumber),
EVMConfig: vm.Config{
Tracer: tracer,
Debug: true,
InferURI: "http://127.0.0.1:5000/infer",
},
}

if false {
input := append(code, input_flag...)
ret, _, _, err = Create(input, &runtimeConfig)
} else {
if len(code) > 0 {
statedb.SetCode(receiver, code)
}
ret, _, err = Call(receiver, common.Hex2Bytes(input_flag), &runtimeConfig)
}

if true {
if debugLogger != nil {
fmt.Fprintln(os.Stderr, "#### TRACE ####")
vm.WriteTrace(os.Stderr, debugLogger.StructLogs())
}
fmt.Fprintln(os.Stderr, "#### LOGS ####")
vm.WriteLogs(os.Stderr, statedb.Logs())
}

if tracer == nil {
fmt.Printf("0x%x\n", ret)
if err != nil {
fmt.Printf(" error: %v\n", err)
}
}

}

0 comments on commit 58254cb

Please sign in to comment.