diff --git a/func.go b/func.go index d98afad..d84afd9 100644 --- a/func.go +++ b/func.go @@ -32,7 +32,6 @@ type Caller struct { // which can't be garbage collected. type newMapEntry struct { callback func(*Caller, []Val) ([]Val, *Trap) - nparams int results []*ValType } @@ -73,7 +72,6 @@ func NewFunc( idx := gNewMapSlab.allocate() gNewMap[idx] = newMapEntry{ callback: f, - nparams: len(ty.Params()), results: ty.Results(), } gLock.Unlock() @@ -95,8 +93,8 @@ func goTrampolineNew( caller_id C.size_t, callerPtr *C.wasmtime_caller_t, env C.size_t, - argsPtr *C.wasm_val_t, - resultsPtr *C.wasm_val_t, + argsPtr *C.wasm_val_vec_t, + resultsPtr *C.wasm_val_vec_t, ) *C.wasm_trap_t { idx := int(env) gLock.Lock() @@ -107,9 +105,9 @@ func goTrampolineNew( caller := &Caller{ptr: callerPtr, freelist: freelist} defer func() { caller.ptr = nil }() - params := make([]Val, entry.nparams) + params := make([]Val, int(argsPtr.size)) var val C.wasm_val_t - base := unsafe.Pointer(argsPtr) + base := unsafe.Pointer(argsPtr.data) for i := 0; i < len(params); i++ { ptr := (*C.wasm_val_t)(unsafe.Pointer(uintptr(base) + uintptr(i)*unsafe.Sizeof(val))) params[i] = mkVal(ptr, freelist) @@ -144,7 +142,7 @@ func goTrampolineNew( return trap.ptr() } - base = unsafe.Pointer(resultsPtr) + base = unsafe.Pointer(resultsPtr.data) for i := 0; i < len(results); i++ { ptr := (*C.wasm_val_t)(unsafe.Pointer(uintptr(base) + uintptr(i)*unsafe.Sizeof(val))) C.wasm_val_copy(ptr, results[i].ptr()) @@ -274,8 +272,8 @@ func goTrampolineWrap( caller_id C.size_t, callerPtr *C.wasmtime_caller_t, env C.size_t, - argsPtr *C.wasm_val_t, - resultsPtr *C.wasm_val_t, + argsPtr *C.wasm_val_vec_t, + resultsPtr *C.wasm_val_vec_t, ) *C.wasm_trap_t { // Convert all our parameters to `[]reflect.Value`, taking special care // for `*Caller` but otherwise reading everything through `Val`. @@ -291,7 +289,7 @@ func goTrampolineWrap( ty := entry.callback.Type() params := make([]reflect.Value, ty.NumIn()) - base := unsafe.Pointer(argsPtr) + base := unsafe.Pointer(argsPtr.data) var raw C.wasm_val_t for i := 0; i < len(params); i++ { if ty.In(i) == reflect.TypeOf(caller) { @@ -321,7 +319,7 @@ func goTrampolineWrap( // And now we write all the results into memory depending on the type // of value that was returned. - base = unsafe.Pointer(resultsPtr) + base = unsafe.Pointer(resultsPtr.data) for _, result := range results { ptr := (*C.wasm_val_t)(base) switch val := result.Interface().(type) { @@ -446,51 +444,49 @@ func (f *Func) Call(args ...interface{}) (interface{}, error) { if len(args) > len(params) { return nil, errors.New("too many arguments provided") } - paramsRaw := make([]C.wasm_val_t, len(args)) - synthesizedParams := make([]Val, 0) + paramsVec := C.wasm_val_vec_t{} + C.wasm_val_vec_new_uninitialized(¶msVec, C.size_t(len(args))) for i, param := range args { + var rawVal Val switch val := param.(type) { case int: switch params[i].Kind() { case KindI32: - paramsRaw[i] = *ValI32(int32(val)).ptr() + rawVal = ValI32(int32(val)) case KindI64: - paramsRaw[i] = *ValI64(int64(val)).ptr() + rawVal = ValI64(int64(val)) default: return nil, errors.New("integer provided for non-integer argument") } case int32: - paramsRaw[i] = *ValI32(val).ptr() + rawVal = ValI32(val) case int64: - paramsRaw[i] = *ValI64(val).ptr() + rawVal = ValI64(val) case float32: - paramsRaw[i] = *ValF32(val).ptr() + rawVal = ValF32(val) case float64: - paramsRaw[i] = *ValF64(val).ptr() + rawVal = ValF64(val) case *Func: - ffi := ValFuncref(val) - paramsRaw[i] = *ffi.ptr() - synthesizedParams = append(synthesizedParams, ffi) + rawVal = ValFuncref(val) case Val: - paramsRaw[i] = *val.ptr() + rawVal = val default: - ffi := ValExternref(val) - paramsRaw[i] = *ffi.ptr() - synthesizedParams = append(synthesizedParams, ffi) + rawVal = ValExternref(val) } - } - resultsRaw := make([]C.wasm_val_t, f.ResultArity()) + base := unsafe.Pointer(paramsVec.data) + ptr := rawVal.ptr() + C.wasm_val_copy( + (*C.wasm_val_t)(unsafe.Pointer(uintptr(base)+unsafe.Sizeof(*ptr)*uintptr(i))), + ptr, + ) + runtime.KeepAlive(rawVal) + } - var paramsPtr, resultsPtr *C.wasm_val_t + resultsVec := C.wasm_val_vec_t{} + C.wasm_val_vec_new_uninitialized(&resultsVec, C.size_t(f.ResultArity())) var trap *C.wasm_trap_t - if len(paramsRaw) > 0 { - paramsPtr = ¶msRaw[0] - } - if len(resultsRaw) > 0 { - resultsPtr = &resultsRaw[0] - } // Use our `freelist` as an anchor to get an identifier which our C // shim shoves into thread-local storage and then pops out on the @@ -502,17 +498,14 @@ func (f *Func) Call(args ...interface{}) (interface{}, error) { err := C.go_wasmtime_func_call( f.ptr(), - paramsPtr, - C.size_t(len(paramsRaw)), - resultsPtr, - C.size_t(len(resultsRaw)), + ¶msVec, + &resultsVec, &trap, caller_id, ) runtime.KeepAlive(f) - runtime.KeepAlive(paramsRaw) runtime.KeepAlive(args) - runtime.KeepAlive(synthesizedParams) + C.wasm_val_vec_delete(¶msVec) // Clear our thread's caller id from the global maps now that the call // is finished. @@ -543,15 +536,21 @@ func (f *Func) Call(args ...interface{}) (interface{}, error) { return nil, wrappedTrap } - if len(resultsRaw) == 0 { + if resultsVec.size == 0 { return nil, nil - } else if len(resultsRaw) == 1 { - return takeVal(&resultsRaw[0], f.freelist).Get(), nil + } else if resultsVec.size == 1 { + ret := mkVal(resultsVec.data, f.freelist).Get() + C.wasm_val_vec_delete(&resultsVec) + return ret, nil } else { - results := make([]Val, len(resultsRaw)) - for i := 0; i < len(resultsRaw); i++ { - results[i] = takeVal(&resultsRaw[i], f.freelist) + results := make([]Val, int(resultsVec.size)) + base := unsafe.Pointer(resultsVec.data) + var val C.wasm_val_t + for i := 0; i < int(resultsVec.size); i++ { + ptr := (*C.wasm_val_t)(unsafe.Pointer(uintptr(base) + unsafe.Sizeof(val)*uintptr(i))) + results[i] = mkVal(ptr, f.freelist) } + C.wasm_val_vec_delete(&resultsVec) return results, nil } diff --git a/instance.go b/instance.go index 6b04e81..1623049 100644 --- a/instance.go +++ b/instance.go @@ -26,28 +26,25 @@ type Instance struct { // This will also run the `start` function of the instance, returning an error // if it traps. func NewInstance(store *Store, module *Module, imports []*Extern) (*Instance, error) { - importsRaw := make([]*C.wasm_extern_t, len(imports)) + importsRaw := C.wasm_extern_vec_t{} + C.wasm_extern_vec_new_uninitialized(&importsRaw, C.size_t(len(imports))) + base := unsafe.Pointer(importsRaw.data) for i, imp := range imports { - importsRaw[i] = imp.ptr() - } - var importsRawPtr **C.wasm_extern_t - if len(imports) > 0 { - importsRawPtr = &importsRaw[0] + ptr := C.wasm_extern_copy(imp.ptr()) + *(**C.wasm_extern_t)(unsafe.Pointer(uintptr(base) + unsafe.Sizeof(ptr)*uintptr(i))) = ptr } var trap *C.wasm_trap_t var ptr *C.wasm_instance_t err := C.wasmtime_instance_new( store.ptr(), module.ptr(), - importsRawPtr, - C.size_t(len(imports)), + &importsRaw, &ptr, &trap, ) runtime.KeepAlive(store) runtime.KeepAlive(module) - runtime.KeepAlive(imports) - runtime.KeepAlive(importsRaw) + C.wasm_extern_vec_delete(&importsRaw) if err != nil { return nil, mkError(err) } diff --git a/shims.c b/shims.c index 63e5c00..a366718 100644 --- a/shims.c +++ b/shims.c @@ -6,19 +6,19 @@ __thread size_t caller_id; static wasm_trap_t* trampoline( const wasmtime_caller_t *caller, void *env, - const wasm_val_t *args, - wasm_val_t *results + const wasm_val_vec_t *args, + wasm_val_vec_t *results ) { - return goTrampolineNew(caller_id, (wasmtime_caller_t*) caller, (size_t) env, (wasm_val_t*) args, results); + return goTrampolineNew(caller_id, (wasmtime_caller_t*) caller, (size_t) env, (wasm_val_vec_t*) args, results); } static wasm_trap_t* wrap_trampoline( const wasmtime_caller_t *caller, void *env, - const wasm_val_t *args, - wasm_val_t *results + const wasm_val_vec_t *args, + wasm_val_vec_t *results ) { - return goTrampolineWrap(caller_id, (wasmtime_caller_t*) caller, (size_t) env, (wasm_val_t*) args, results); + return goTrampolineWrap(caller_id, (wasmtime_caller_t*) caller, (size_t) env, (wasm_val_vec_t*) args, results); } wasm_func_t *c_func_new_with_env(wasm_store_t *store, wasm_functype_t *ty, size_t env, int wrap) { @@ -29,16 +29,14 @@ wasm_func_t *c_func_new_with_env(wasm_store_t *store, wasm_functype_t *ty, size_ wasmtime_error_t *go_wasmtime_func_call( wasm_func_t *func, - const wasm_val_t *args, - size_t num_args, - wasm_val_t *results, - size_t num_results, + const wasm_val_vec_t *args, + wasm_val_vec_t *results, wasm_trap_t **trap, size_t go_id ) { size_t prev_caller_id = caller_id; caller_id = go_id; - wasmtime_error_t *ret = wasmtime_func_call(func, args, num_args, results, num_results, trap); + wasmtime_error_t *ret = wasmtime_func_call(func, args, results, trap); caller_id = prev_caller_id; return ret; } diff --git a/shims.h b/shims.h index 0c4e7b6..ea57ad4 100644 --- a/shims.h +++ b/shims.h @@ -23,10 +23,8 @@ void go_externref_new_with_finalizer( ); wasmtime_error_t *go_wasmtime_func_call( wasm_func_t *func, - const wasm_val_t *args, - size_t num_args, - wasm_val_t *results, - size_t num_results, + const wasm_val_vec_t *args, + wasm_val_vec_t *results, wasm_trap_t **trap, size_t go_id );