Skip to content

Commit

Permalink
jit: support unreachable instruciton, and stack trace. (#70)
Browse files Browse the repository at this point in the history
Signed-off-by: Takeshi Yoneda takeshi@tetrate.io
  • Loading branch information
mathetake authored Dec 10, 2021
1 parent 4ca8afd commit 340e4d4
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 58 deletions.
112 changes: 83 additions & 29 deletions wasm/jit/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"math"
"reflect"
"strings"
"unsafe"

"github.com/tetratelabs/wazero/wasm"
Expand Down Expand Up @@ -34,20 +35,47 @@ type engine struct {
compiledWasmFunctions []*compiledWasmFunction
compiledWasmFunctionIndex map[*wasm.FunctionInstance]int64
// Store the host functions and indexes.
hostFunctions []hostFunction
hostFunctionIndex map[*wasm.FunctionInstance]int64
compiledHostFunctions []*compiledHostFunction
compiledHostFunctionIndex map[*wasm.FunctionInstance]int64
}

type hostFunction = func(ctx *wasm.HostFunctionCallContext)

func (e *engine) Call(f *wasm.FunctionInstance, args ...uint64) (returns []uint64, err error) {
prevFrame := e.callFrameStack
// We ensure that this Call method never panics as
// this Call method is indirectly invoked by embedders via store.CallFunction,
// and we have to make sure that all the runtime errors, including the one happening inside
// host functions, will be capatured as errors, not panics.
defer func() {
if v := recover(); v != nil {
top := e.callFrameStack
var frames []string
var counter int
for top != prevFrame {
frames = append(frames, fmt.Sprintf("\t%d: %s", counter, top.getFunctionName()))
top = top.caller
counter++
// TODO: include DWARF symbols. See #58
}
err2, ok := v.(error)
if ok {
err = fmt.Errorf("wasm runtime error: %w", err2)
} else {
err = fmt.Errorf("wasm runtime error: %v", v)
}

if len(frames) > 0 {
err = fmt.Errorf("%w\nwasm backtrace:\n%s", err, strings.Join(frames, "\n"))
}
}
}()

for _, arg := range args {
e.push(arg)
}
// Note that there's no conflict between e.hostFunctionIndex and e.compiledWasmFunctionIndex,
// meaning that each *wasm.FunctionInstance is assigned to either host function index or wasm function one.
if index, ok := e.hostFunctionIndex[f]; ok {
e.hostFunctions[index](&wasm.HostFunctionCallContext{Memory: f.ModuleInstance.Memory})
if index, ok := e.compiledHostFunctionIndex[f]; ok {
e.compiledHostFunctions[index].f(&wasm.HostFunctionCallContext{Memory: f.ModuleInstance.Memory})
} else if index, ok := e.compiledWasmFunctionIndex[f]; ok {
f := e.compiledWasmFunctions[index]
e.exec(f)
Expand All @@ -70,11 +98,11 @@ func (e *engine) PreCompile(fs []*wasm.FunctionInstance) error {
var newUniqueHostFunctions, newUniqueWasmFunctions int
for _, f := range fs {
if f.HostFunction != nil {
if _, ok := e.hostFunctionIndex[f]; ok {
if _, ok := e.compiledHostFunctionIndex[f]; ok {
continue
}
id := getNewID(e.hostFunctionIndex)
e.hostFunctionIndex[f] = id
id := getNewID(e.compiledHostFunctionIndex)
e.compiledHostFunctionIndex[f] = id
newUniqueHostFunctions++
} else {
if _, ok := e.compiledWasmFunctionIndex[f]; ok {
Expand All @@ -85,9 +113,9 @@ func (e *engine) PreCompile(fs []*wasm.FunctionInstance) error {
newUniqueWasmFunctions++
}
}
e.hostFunctions = append(
e.hostFunctions,
make([]hostFunction, newUniqueHostFunctions)...,
e.compiledHostFunctions = append(
e.compiledHostFunctions,
make([]*compiledHostFunction, newUniqueHostFunctions)...,
)
e.compiledWasmFunctions = append(
e.compiledWasmFunctions,
Expand All @@ -102,8 +130,8 @@ func getNewID(idMap map[*wasm.FunctionInstance]int64) int64 {

func (e *engine) Compile(f *wasm.FunctionInstance) error {
if f.HostFunction != nil {
id := e.hostFunctionIndex[f]
if e.hostFunctions[id] != nil {
id := e.compiledHostFunctionIndex[f]
if e.compiledHostFunctions[id] != nil {
// Already compiled.
return nil
}
Expand Down Expand Up @@ -140,7 +168,7 @@ func (e *engine) Compile(f *wasm.FunctionInstance) error {
}
}
}
e.hostFunctions[id] = hf
e.compiledHostFunctions[id] = &compiledHostFunction{f: hf, name: f.Name}
} else {
id := e.compiledWasmFunctionIndex[f]
if e.compiledWasmFunctions[id] != nil {
Expand All @@ -166,7 +194,7 @@ func newEngine() *engine {
e := &engine{
stack: make([]uint64, initialStackSize),
compiledWasmFunctionIndex: make(map[*wasm.FunctionInstance]int64),
hostFunctionIndex: make(map[*wasm.FunctionInstance]int64),
compiledHostFunctionIndex: make(map[*wasm.FunctionInstance]int64),
}
return e
}
Expand Down Expand Up @@ -195,6 +223,8 @@ const (
jitCallStatusCodeCallBuiltInFunction
// jitCallStatusCodeCallWasmFunction means the jitcall returns to make a host function call.
jitCallStatusCodeCallHostFunction
// jitCallStatusCodeUnreachable means the function invocation reaches "unreachable" instruction.
jitCallStatusCodeUnreachable
// TODO: trap, etc?
)

Expand All @@ -208,6 +238,8 @@ func (s jitCallStatusCode) String() (ret string) {
ret = "call_builtin_function"
case jitCallStatusCodeCallHostFunction:
ret = "call_host_function"
case jitCallStatusCodeUnreachable:
ret = "unreachable"
}
return
}
Expand All @@ -226,18 +258,34 @@ type callFrame struct {
continuationAddress uintptr
continuationStackPointer uint64
baseStackPointer uint64
f *compiledWasmFunction
wasmFunction *compiledWasmFunction
hostFunction *compiledHostFunction
caller *callFrame
}

func (c *callFrame) String() string {
return fmt.Sprintf(
"[continuation address=%d, continuation stack pointer=%d, base stack pointer=%d]",
c.continuationAddress, c.continuationStackPointer, c.baseStackPointer,
"[%s: continuation address=%d, continuation stack pointer=%d, base stack pointer=%d]",
c.getFunctionName(), c.continuationAddress, c.continuationStackPointer, c.baseStackPointer,
)
}

func (c *callFrame) getFunctionName() string {
if c.wasmFunction != nil {
return c.wasmFunction.source.Name
} else {
return c.hostFunction.name
}
}

type compiledHostFunction = struct {
f func(ctx *wasm.HostFunctionCallContext)
name string
}

type compiledWasmFunction struct {
// The source function instance from which this is compiled.
source *wasm.FunctionInstance
// inputs,returns represents the number of input/returns of function.
inputs, returns uint64
// codeSegment is holding the compiled native code as a byte slice.
Expand Down Expand Up @@ -280,7 +328,7 @@ func (e *engine) maybeGrowStack(maxStackPointer uint64) {
func (e *engine) exec(f *compiledWasmFunction) {
e.callFrameStack = &callFrame{
continuationAddress: f.codeInitialAddress,
f: f,
wasmFunction: f,
caller: nil,
continuationStackPointer: f.inputs,
}
Expand All @@ -299,7 +347,7 @@ func (e *engine) exec(f *compiledWasmFunction) {
jitcall(
currentFrame.continuationAddress,
uintptr(unsafe.Pointer(e)),
currentFrame.f.memoryAddress,
currentFrame.wasmFunction.memoryAddress,
)

// Check the status code from JIT code.
Expand All @@ -319,13 +367,13 @@ func (e *engine) exec(f *compiledWasmFunction) {
nextFunc := e.compiledWasmFunctions[e.functionCallIndex]
// Calculate the continuation address so
// we can resume this caller function frame.
currentFrame.continuationAddress = currentFrame.f.codeInitialAddress + e.continuationAddressOffset
currentFrame.continuationAddress = currentFrame.wasmFunction.codeInitialAddress + e.continuationAddressOffset
currentFrame.continuationStackPointer = e.currentStackPointer + nextFunc.returns - nextFunc.inputs
currentFrame.baseStackPointer = e.currentBaseStackPointer
// Create the callee frame.
frame := &callFrame{
continuationAddress: nextFunc.codeInitialAddress,
f: nextFunc,
wasmFunction: nextFunc,
// Set the caller frame so we can return back to the current frame!
caller: currentFrame,
// Set the base pointer to the beginning of the function inputs
Expand All @@ -339,17 +387,23 @@ func (e *engine) exec(f *compiledWasmFunction) {
// Set the stack pointer so that base+sp would point to the top of function inputs.
e.currentStackPointer = nextFunc.inputs
case jitCallStatusCodeCallBuiltInFunction:
// TODO: check the signature and modify stack pointer.
switch e.functionCallIndex {
case builtinFunctionIndexGrowMemory:
v := e.pop()
e.memoryGrow(currentFrame.f.memory, v)
e.memoryGrow(currentFrame.wasmFunction.memory, v)
}
currentFrame.continuationAddress = currentFrame.f.codeInitialAddress + e.continuationAddressOffset
currentFrame.continuationAddress = currentFrame.wasmFunction.codeInitialAddress + e.continuationAddressOffset
case jitCallStatusCodeCallHostFunction:
e.hostFunctions[e.functionCallIndex](&wasm.HostFunctionCallContext{Memory: f.memory})
// TODO: check the signature and modify stack pointer.
currentFrame.continuationAddress = currentFrame.f.codeInitialAddress + e.continuationAddressOffset
targetHostFunction := e.compiledHostFunctions[e.functionCallIndex]
currentFrame.continuationAddress = currentFrame.wasmFunction.codeInitialAddress + e.continuationAddressOffset
// Push the call frame for this host function.
e.callFrameStack = &callFrame{hostFunction: targetHostFunction, caller: currentFrame}
// Call into the host function.
targetHostFunction.f(&wasm.HostFunctionCallContext{Memory: f.memory})
// Pop the call frame.
e.callFrameStack = currentFrame
case jitCallStatusCodeUnreachable:
panic("unreachable")
}
}
}
Expand Down
51 changes: 30 additions & 21 deletions wasm/jit/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ import (

"github.com/stretchr/testify/require"

"github.com/tetratelabs/wazero/wasi"
"github.com/tetratelabs/wazero/wasm"
"github.com/tetratelabs/wazero/wasm/wazeroir"
)

// Ensures that the offset consts do not drift when we manipulate the engine struct.
Expand All @@ -32,25 +30,36 @@ func TestEngine_fibonacci(t *testing.T) {
require.NoError(t, err)
mod, err := wasm.DecodeModule(buf)
require.NoError(t, err)
store := wasm.NewStore(wazeroir.NewEngine())
require.NoError(t, err)
err = wasi.NewEnvironment().Register(store)
store := wasm.NewStore(NewEngine())
require.NoError(t, err)
err = store.Instantiate(mod, "test")
require.NoError(t, err)
m, ok := store.ModuleInstances["test"]
require.True(t, ok)
exp, ok := m.Exports["fib"]
require.True(t, ok)
f := exp.Function
eng := newEngine()
err = eng.PreCompile([]*wasm.FunctionInstance{f})
out, _, err := store.CallFunction("test", "fib", 20)
require.NoError(t, err)
err = eng.Compile(f)
require.Equal(t, uint64(10946), out[0])
}

func TestEngine_unreachable(t *testing.T) {
if runtime.GOARCH != "amd64" {
t.Skip()
}
buf, err := os.ReadFile("testdata/unreachable.wasm")
require.NoError(t, err)
out, err := eng.Call(f, 20)
mod, err := wasm.DecodeModule(buf)
require.NoError(t, err)
require.Equal(t, uint64(10946), out[0])
store := wasm.NewStore(NewEngine())
require.NoError(t, err)
err = store.Instantiate(mod, "test")
require.NoError(t, err)
_, _, err = store.CallFunction("test", "cause_unreachable")
exp := `wasm runtime error: unreachable
wasm backtrace:
0: three
1: two
2: one
3: cause_unreachable`
require.Error(t, err)
require.Equal(t, exp, err.Error())
}

func TestEngine_PreCompile(t *testing.T) {
Expand All @@ -70,18 +79,18 @@ func TestEngine_PreCompile(t *testing.T) {
// Check the indexes.
require.Len(t, eng.compiledWasmFunctions, 3)
require.Len(t, eng.compiledWasmFunctionIndex, 3)
require.Len(t, eng.hostFunctions, 1)
require.Len(t, eng.hostFunctionIndex, 1)
require.Len(t, eng.compiledHostFunctions, 1)
require.Len(t, eng.compiledHostFunctionIndex, 1)
prevCompiledFunctions := make([]*compiledWasmFunction, len(eng.compiledWasmFunctions))
prevHostFunctions := make([]hostFunction, len(eng.hostFunctions))
prevHostFunctions := make([]*compiledHostFunction, len(eng.compiledHostFunctions))
copy(prevCompiledFunctions, eng.compiledWasmFunctions)
copy(prevHostFunctions, eng.hostFunctions)
copy(prevHostFunctions, eng.compiledHostFunctions)
err = eng.PreCompile(fs)
// Precompiling same functions should be noop.
require.NoError(t, err)
require.Len(t, eng.compiledWasmFunctionIndex, 3)
require.Len(t, eng.hostFunctionIndex, 1)
require.Equal(t, prevHostFunctions, eng.hostFunctions)
require.Len(t, eng.compiledHostFunctionIndex, 1)
require.Equal(t, prevHostFunctions, eng.compiledHostFunctions)
require.Equal(t, prevCompiledFunctions, eng.compiledWasmFunctions)
}

Expand Down
11 changes: 9 additions & 2 deletions wasm/jit/jit_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (e *engine) compileWasmFunction(f *wasm.FunctionInstance) (*compiledWasmFun
for _, op := range ir.Operations {
switch o := op.(type) {
case *wazeroir.OperationUnreachable:
return nil, fmt.Errorf("unsupported operation in JIT compiler: %v", o)
builder.handleUnreachable()
case *wazeroir.OperationLabel:
if err := builder.handleLabel(o); err != nil {
return nil, fmt.Errorf("error handling label operation: %w", err)
Expand Down Expand Up @@ -224,6 +224,7 @@ func (e *engine) compileWasmFunction(f *wasm.FunctionInstance) (*compiledWasmFun

func (b *amd64Builder) newCompiledWasmFunction(code []byte) *compiledWasmFunction {
cf := &compiledWasmFunction{
source: b.f,
codeSegment: code,
inputs: uint64(len(b.f.Signature.InputTypes)),
returns: uint64(len(b.f.Signature.ReturnTypes)),
Expand Down Expand Up @@ -298,6 +299,12 @@ func (b *amd64Builder) newProg() (prog *obj.Prog) {
return
}

func (b *amd64Builder) handleUnreachable() {
b.releaseAllRegistersToStack()
b.setJITStatus(jitCallStatusCodeUnreachable)
b.returnFunction()
}

func (b *amd64Builder) handleBr(o *wazeroir.OperationBr) error {
if o.Target.IsReturnTarget() {
// Release all the registers as our calling convention requires the callee-save.
Expand Down Expand Up @@ -491,7 +498,7 @@ func (b *amd64Builder) handleLabel(o *wazeroir.OperationLabel) error {
func (b *amd64Builder) handleCall(o *wazeroir.OperationCall) error {
target := b.f.ModuleInstance.Functions[o.FunctionIndex]
if target.HostFunction != nil {
index := b.eng.hostFunctionIndex[target]
index := b.eng.compiledHostFunctionIndex[target]
b.callHostFunctionFromConstIndex(index)
} else {
index := b.eng.compiledWasmFunctionIndex[target]
Expand Down
Loading

0 comments on commit 340e4d4

Please sign in to comment.