Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jit: support unreachable instruciton, and stack trace. #70

Merged
merged 7 commits into from
Dec 10, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 79 additions & 29 deletions wasm/jit/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"fmt"
"math"
"reflect"
"strings"
"unsafe"

"github.com/tetratelabs/wazero/wasm"
Expand Down Expand Up @@ -34,20 +35,43 @@ type engine struct {
compiledWasmFunctions []*compiledWasmFunction
compiledWasmFunctionIndex map[*wasm.FunctionInstance]int64
// Store the host functions and indexes.
hostFunctions []hostFunction
hostFunctionIndex map[*wasm.FunctionInstance]int64
compiledHostFunctions []*compiledHostFunction
compiledHostFunctionIndex map[*wasm.FunctionInstance]int64
}

type hostFunction = func(ctx *wasm.HostFunctionCallContext)

func (e *engine) Call(f *wasm.FunctionInstance, args ...uint64) (returns []uint64, err error) {
prevFrame := e.callFrameStack
defer func() {
mathetake marked this conversation as resolved.
Show resolved Hide resolved
if v := recover(); v != nil {
top := e.callFrameStack
var traces []string
mathetake marked this conversation as resolved.
Show resolved Hide resolved
var counter int
for top != prevFrame {
traces = append(traces, fmt.Sprintf("\t%d: %s", counter, top.getFunctionName()))
top = top.caller
counter++
// TODO: include DWARF symbols.
mathetake marked this conversation as resolved.
Show resolved Hide resolved
}
err2, ok := v.(error)
if ok {
err = fmt.Errorf("wasm runtime error: %w", err2)
} else {
err = fmt.Errorf("wasm runtime error: %v", v)
}

if len(traces) > 0 {
err = fmt.Errorf("%w\nwasm backtrace:\n%s", err, strings.Join(traces, "\n"))
}
}
}()

for _, arg := range args {
e.push(arg)
}
// Note that there's no conflict between e.hostFunctionIndex and e.compiledWasmFunctionIndex,
// meaning that each *wasm.FunctionInstance is assigned to either host function index or wasm function one.
if index, ok := e.hostFunctionIndex[f]; ok {
e.hostFunctions[index](&wasm.HostFunctionCallContext{Memory: f.ModuleInstance.Memory})
if index, ok := e.compiledHostFunctionIndex[f]; ok {
e.compiledHostFunctions[index].f(&wasm.HostFunctionCallContext{Memory: f.ModuleInstance.Memory})
} else if index, ok := e.compiledWasmFunctionIndex[f]; ok {
f := e.compiledWasmFunctions[index]
e.exec(f)
Expand All @@ -70,11 +94,11 @@ func (e *engine) PreCompile(fs []*wasm.FunctionInstance) error {
var newUniqueHostFunctions, newUniqueWasmFunctions int
for _, f := range fs {
if f.HostFunction != nil {
if _, ok := e.hostFunctionIndex[f]; ok {
if _, ok := e.compiledHostFunctionIndex[f]; ok {
continue
}
id := getNewID(e.hostFunctionIndex)
e.hostFunctionIndex[f] = id
id := getNewID(e.compiledHostFunctionIndex)
e.compiledHostFunctionIndex[f] = id
newUniqueHostFunctions++
} else {
if _, ok := e.compiledWasmFunctionIndex[f]; ok {
Expand All @@ -85,9 +109,9 @@ func (e *engine) PreCompile(fs []*wasm.FunctionInstance) error {
newUniqueWasmFunctions++
}
}
e.hostFunctions = append(
e.hostFunctions,
make([]hostFunction, newUniqueHostFunctions)...,
e.compiledHostFunctions = append(
e.compiledHostFunctions,
make([]*compiledHostFunction, newUniqueHostFunctions)...,
)
e.compiledWasmFunctions = append(
e.compiledWasmFunctions,
Expand All @@ -102,8 +126,8 @@ func getNewID(idMap map[*wasm.FunctionInstance]int64) int64 {

func (e *engine) Compile(f *wasm.FunctionInstance) error {
if f.HostFunction != nil {
id := e.hostFunctionIndex[f]
if e.hostFunctions[id] != nil {
id := e.compiledHostFunctionIndex[f]
if e.compiledHostFunctions[id] != nil {
// Already compiled.
return nil
}
Expand Down Expand Up @@ -140,7 +164,7 @@ func (e *engine) Compile(f *wasm.FunctionInstance) error {
}
}
}
e.hostFunctions[id] = hf
e.compiledHostFunctions[id] = &compiledHostFunction{f: hf, name: f.Name}
} else {
id := e.compiledWasmFunctionIndex[f]
if e.compiledWasmFunctions[id] != nil {
Expand All @@ -166,7 +190,7 @@ func newEngine() *engine {
e := &engine{
stack: make([]uint64, initialStackSize),
compiledWasmFunctionIndex: make(map[*wasm.FunctionInstance]int64),
hostFunctionIndex: make(map[*wasm.FunctionInstance]int64),
compiledHostFunctionIndex: make(map[*wasm.FunctionInstance]int64),
}
return e
}
Expand Down Expand Up @@ -195,6 +219,8 @@ const (
jitCallStatusCodeCallBuiltInFunction
// jitCallStatusCodeCallWasmFunction means the jitcall returns to make a host function call.
jitCallStatusCodeCallHostFunction
// jitCallStatusCodeUnreachable means the function invocation reaches "unreachable" instruction.
jitCallStatusCodeUnreachable
// TODO: trap, etc?
)

Expand All @@ -208,6 +234,8 @@ func (s jitCallStatusCode) String() (ret string) {
ret = "call_builtin_function"
case jitCallStatusCodeCallHostFunction:
ret = "call_host_function"
case jitCallStatusCodeUnreachable:
ret = "unreachable"
}
return
}
Expand All @@ -226,18 +254,34 @@ type callFrame struct {
continuationAddress uintptr
continuationStackPointer uint64
baseStackPointer uint64
f *compiledWasmFunction
wasmFunction *compiledWasmFunction
hostFunction *compiledHostFunction
caller *callFrame
}

func (c *callFrame) String() string {
return fmt.Sprintf(
"[continuation address=%d, continuation stack pointer=%d, base stack pointer=%d]",
c.continuationAddress, c.continuationStackPointer, c.baseStackPointer,
"[%s: continuation address=%d, continuation stack pointer=%d, base stack pointer=%d]",
c.getFunctionName(), c.continuationAddress, c.continuationStackPointer, c.baseStackPointer,
)
}

func (c *callFrame) getFunctionName() string {
if c.wasmFunction != nil {
return c.wasmFunction.originalFunctionInstance.Name
} else {
return c.hostFunction.name
}
}

type compiledHostFunction = struct {
f func(ctx *wasm.HostFunctionCallContext)
name string
}

type compiledWasmFunction struct {
// FunctionInstance from which this is compiled.
originalFunctionInstance *wasm.FunctionInstance
mathetake marked this conversation as resolved.
Show resolved Hide resolved
// inputs,returns represents the number of input/returns of function.
inputs, returns uint64
// codeSegment is holding the compiled native code as a byte slice.
Expand Down Expand Up @@ -280,7 +324,7 @@ func (e *engine) maybeGrowStack(maxStackPointer uint64) {
func (e *engine) exec(f *compiledWasmFunction) {
e.callFrameStack = &callFrame{
continuationAddress: f.codeInitialAddress,
f: f,
wasmFunction: f,
caller: nil,
continuationStackPointer: f.inputs,
}
Expand All @@ -299,7 +343,7 @@ func (e *engine) exec(f *compiledWasmFunction) {
jitcall(
currentFrame.continuationAddress,
uintptr(unsafe.Pointer(e)),
currentFrame.f.memoryAddress,
currentFrame.wasmFunction.memoryAddress,
)

// Check the status code from JIT code.
Expand All @@ -319,13 +363,13 @@ func (e *engine) exec(f *compiledWasmFunction) {
nextFunc := e.compiledWasmFunctions[e.functionCallIndex]
// Calculate the continuation address so
// we can resume this caller function frame.
currentFrame.continuationAddress = currentFrame.f.codeInitialAddress + e.continuationAddressOffset
currentFrame.continuationAddress = currentFrame.wasmFunction.codeInitialAddress + e.continuationAddressOffset
currentFrame.continuationStackPointer = e.currentStackPointer + nextFunc.returns - nextFunc.inputs
currentFrame.baseStackPointer = e.currentBaseStackPointer
// Create the callee frame.
frame := &callFrame{
continuationAddress: nextFunc.codeInitialAddress,
f: nextFunc,
wasmFunction: nextFunc,
// Set the caller frame so we can return back to the current frame!
caller: currentFrame,
// Set the base pointer to the beginning of the function inputs
Expand All @@ -339,17 +383,23 @@ func (e *engine) exec(f *compiledWasmFunction) {
// Set the stack pointer so that base+sp would point to the top of function inputs.
e.currentStackPointer = nextFunc.inputs
case jitCallStatusCodeCallBuiltInFunction:
// TODO: check the signature and modify stack pointer.
switch e.functionCallIndex {
case builtinFunctionIndexGrowMemory:
v := e.pop()
e.memoryGrow(currentFrame.f.memory, v)
e.memoryGrow(currentFrame.wasmFunction.memory, v)
}
currentFrame.continuationAddress = currentFrame.f.codeInitialAddress + e.continuationAddressOffset
currentFrame.continuationAddress = currentFrame.wasmFunction.codeInitialAddress + e.continuationAddressOffset
case jitCallStatusCodeCallHostFunction:
e.hostFunctions[e.functionCallIndex](&wasm.HostFunctionCallContext{Memory: f.memory})
// TODO: check the signature and modify stack pointer.
currentFrame.continuationAddress = currentFrame.f.codeInitialAddress + e.continuationAddressOffset
targetHostFunction := e.compiledHostFunctions[e.functionCallIndex]
currentFrame.continuationAddress = currentFrame.wasmFunction.codeInitialAddress + e.continuationAddressOffset
// Push the call frame for this host function.
e.callFrameStack = &callFrame{hostFunction: targetHostFunction, caller: currentFrame}
// Call into the host function.
targetHostFunction.f(&wasm.HostFunctionCallContext{Memory: f.memory})
// Pop the call frame.
e.callFrameStack = currentFrame
case jitCallStatusCodeUnreachable:
panic("unreachable")
}
}
}
Expand Down
51 changes: 30 additions & 21 deletions wasm/jit/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ import (

"github.com/stretchr/testify/require"

"github.com/tetratelabs/wazero/wasi"
"github.com/tetratelabs/wazero/wasm"
"github.com/tetratelabs/wazero/wasm/wazeroir"
)

// Ensures that the offset consts do not drift when we manipulate the engine struct.
Expand All @@ -32,25 +30,36 @@ func TestEngine_fibonacci(t *testing.T) {
require.NoError(t, err)
mod, err := wasm.DecodeModule(buf)
require.NoError(t, err)
store := wasm.NewStore(wazeroir.NewEngine())
require.NoError(t, err)
err = wasi.NewEnvironment().Register(store)
store := wasm.NewStore(NewEngine())
require.NoError(t, err)
err = store.Instantiate(mod, "test")
require.NoError(t, err)
m, ok := store.ModuleInstances["test"]
require.True(t, ok)
exp, ok := m.Exports["fib"]
require.True(t, ok)
f := exp.Function
eng := newEngine()
err = eng.PreCompile([]*wasm.FunctionInstance{f})
out, _, err := store.CallFunction("test", "fib", 20)
require.NoError(t, err)
err = eng.Compile(f)
require.Equal(t, uint64(10946), out[0])
}

func TestEngine_unreachable(t *testing.T) {
if runtime.GOARCH != "amd64" {
t.Skip()
}
buf, err := os.ReadFile("testdata/unreachable.wasm")
require.NoError(t, err)
out, err := eng.Call(f, 20)
mod, err := wasm.DecodeModule(buf)
require.NoError(t, err)
require.Equal(t, uint64(10946), out[0])
store := wasm.NewStore(NewEngine())
require.NoError(t, err)
err = store.Instantiate(mod, "test")
require.NoError(t, err)
_, _, err = store.CallFunction("test", "cause_unreachable")
exp := `wasm runtime error: unreachable
wasm backtrace:
0: three
1: two
2: one
3: cause_unreachable`
require.Error(t, err)
require.Equal(t, exp, err.Error())
}

func TestEngine_PreCompile(t *testing.T) {
Expand All @@ -70,18 +79,18 @@ func TestEngine_PreCompile(t *testing.T) {
// Check the indexes.
require.Len(t, eng.compiledWasmFunctions, 3)
require.Len(t, eng.compiledWasmFunctionIndex, 3)
require.Len(t, eng.hostFunctions, 1)
require.Len(t, eng.hostFunctionIndex, 1)
require.Len(t, eng.compiledHostFunctions, 1)
require.Len(t, eng.compiledHostFunctionIndex, 1)
prevCompiledFunctions := make([]*compiledWasmFunction, len(eng.compiledWasmFunctions))
prevHostFunctions := make([]hostFunction, len(eng.hostFunctions))
prevHostFunctions := make([]*compiledHostFunction, len(eng.compiledHostFunctions))
copy(prevCompiledFunctions, eng.compiledWasmFunctions)
copy(prevHostFunctions, eng.hostFunctions)
copy(prevHostFunctions, eng.compiledHostFunctions)
err = eng.PreCompile(fs)
// Precompiling same functions should be noop.
require.NoError(t, err)
require.Len(t, eng.compiledWasmFunctionIndex, 3)
require.Len(t, eng.hostFunctionIndex, 1)
require.Equal(t, prevHostFunctions, eng.hostFunctions)
require.Len(t, eng.compiledHostFunctionIndex, 1)
require.Equal(t, prevHostFunctions, eng.compiledHostFunctions)
require.Equal(t, prevCompiledFunctions, eng.compiledWasmFunctions)
}

Expand Down
21 changes: 14 additions & 7 deletions wasm/jit/jit_amd64.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (e *engine) compileWasmFunction(f *wasm.FunctionInstance) (*compiledWasmFun
for _, op := range ir.Operations {
switch o := op.(type) {
case *wazeroir.OperationUnreachable:
return nil, fmt.Errorf("unsupported operation in JIT compiler: %v", o)
builder.handleUnreachable()
case *wazeroir.OperationLabel:
if err := builder.handleLabel(o); err != nil {
return nil, fmt.Errorf("error handling label operation: %w", err)
Expand Down Expand Up @@ -224,11 +224,12 @@ func (e *engine) compileWasmFunction(f *wasm.FunctionInstance) (*compiledWasmFun

func (b *amd64Builder) newCompiledWasmFunction(code []byte) *compiledWasmFunction {
cf := &compiledWasmFunction{
codeSegment: code,
inputs: uint64(len(b.f.Signature.InputTypes)),
returns: uint64(len(b.f.Signature.ReturnTypes)),
memory: b.f.ModuleInstance.Memory,
maxStackPointer: b.locationStack.maxStackPointer,
originalFunctionInstance: b.f,
codeSegment: code,
inputs: uint64(len(b.f.Signature.InputTypes)),
returns: uint64(len(b.f.Signature.ReturnTypes)),
memory: b.f.ModuleInstance.Memory,
maxStackPointer: b.locationStack.maxStackPointer,
}
if cf.memory != nil {
cf.memoryAddress = uintptr(unsafe.Pointer(&cf.memory.Buffer[0]))
Expand Down Expand Up @@ -298,6 +299,12 @@ func (b *amd64Builder) newProg() (prog *obj.Prog) {
return
}

func (b *amd64Builder) handleUnreachable() {
b.releaseAllRegistersToStack()
b.setJITStatus(jitCallStatusCodeUnreachable)
b.returnFunction()
}

func (b *amd64Builder) handleBr(o *wazeroir.OperationBr) error {
if o.Target.IsReturnTarget() {
// Release all the registers as our calling convention requires the callee-save.
Expand Down Expand Up @@ -491,7 +498,7 @@ func (b *amd64Builder) handleLabel(o *wazeroir.OperationLabel) error {
func (b *amd64Builder) handleCall(o *wazeroir.OperationCall) error {
target := b.f.ModuleInstance.Functions[o.FunctionIndex]
if target.HostFunction != nil {
index := b.eng.hostFunctionIndex[target]
index := b.eng.compiledHostFunctionIndex[target]
b.callHostFunctionFromConstIndex(index)
} else {
index := b.eng.compiledWasmFunctionIndex[target]
Expand Down
Loading