From 454f372ce840c38083a1926ea1321bbbab782e96 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Tue, 18 Feb 2020 13:34:50 -0800 Subject: [PATCH] Add API to statically assert signature of a `Func` This commit add a family of APIs to `Func` named `getN` where `N` is the number of arguments. Each function will attempt to statically assert the signature of a `Func` and, if matching, returns a corresponding closure which can be used to invoke the underlying function. The purpose of this commit is to add a highly optimized way to enter a wasm module, performing type checks up front and avoiding all the costs of boxing and unboxing arguments within a `Val`. In general this should be much more optimized than the previous `call` API for entering a wasm module, if the signature is statically known. --- crates/api/src/func.rs | 192 ++++++++++++++++-- crates/api/src/lib.rs | 2 +- crates/api/tests/func.rs | 90 +++++++- crates/runtime/signalhandlers/Trampolines.cpp | 27 +-- crates/runtime/src/instance.rs | 11 +- crates/runtime/src/lib.rs | 4 +- crates/runtime/src/traphandlers.rs | 63 +++--- crates/wasi/src/lib.rs | 10 +- 8 files changed, 319 insertions(+), 80 deletions(-) diff --git a/crates/api/src/func.rs b/crates/api/src/func.rs index 187f379edd65..e62442b46a20 100644 --- a/crates/api/src/func.rs +++ b/crates/api/src/func.rs @@ -1,10 +1,12 @@ use crate::callable::{NativeCallable, WasmtimeFn, WrappedCallable}; use crate::{Callable, FuncType, Store, Trap, Val, ValType}; use std::fmt; +use std::mem; use std::panic::{self, AssertUnwindSafe}; +use std::ptr; use std::rc::Rc; use wasmtime_jit::InstanceHandle; -use wasmtime_runtime::VMContext; +use wasmtime_runtime::{VMContext, VMFunctionBody}; /// A WebAssembly function which can be called. /// @@ -38,7 +40,7 @@ macro_rules! wrappers { pub fn $name(store: &Store, func: F) -> Func where F: Fn($($args),*) -> R + 'static, - $($args: WasmArg,)* + $($args: WasmTy,)* R: WasmRet, { #[allow(non_snake_case)] @@ -49,7 +51,7 @@ macro_rules! wrappers { ) -> R::Abi where F: Fn($($args),*) -> R + 'static, - $($args: WasmArg,)* + $($args: WasmTy,)* R: WasmRet, { let ret = { @@ -88,6 +90,71 @@ macro_rules! wrappers { )*) } +macro_rules! getters { + ($( + $(#[$doc:meta])* + ($name:ident $(,$args:ident)*) + )*) => ($( + $(#[$doc])* + #[allow(non_snake_case)] + pub fn $name<$($args,)* R>(&self) + -> Option Result> + where + $($args: WasmTy,)* + R: WasmTy, + { + // Verify all the paramers match the expected parameters, and that + // there are no extra parameters... + let mut params = self.ty().params().iter().cloned(); + eprintln!("{:?}", params); + $( + if !$args::matches(&mut params) { + return None; + } + )* + if !params.next().is_none() { + return None; + } + + // ... then do the same for the results... + let mut results = self.ty().results().iter().cloned(); + if !R::matches(&mut results) { + return None; + } + if !results.next().is_none() { + return None; + } + + // ... and then once we've passed the typechecks we can hand out our + // object since our `transmute` below should be safe! + let (address, vmctx) = match self.wasmtime_export() { + wasmtime_runtime::Export::Function { address, vmctx, signature: _} => { + (*address, *vmctx) + } + _ => return None, + }; + Some(move |$($args: $args),*| -> Result { + unsafe { + let f = mem::transmute::< + *const VMFunctionBody, + unsafe extern "C" fn( + *mut VMContext, + *mut VMContext, + $($args::Abi,)* + ) -> R::Abi, + >(address); + let mut ret = None; + $(let $args = $args.into_abi();)* + wasmtime_runtime::catch_traps(vmctx, || { + ret = Some(f(vmctx, ptr::null_mut(), $($args,)*)); + }).map_err(Trap::from_jit)?; + Ok(R::from_abi(vmctx, ret.unwrap())) + } + }) + } + )*) +} + impl Func { /// Creates a new `Func` with the given arguments, typically to create a /// user-defined function to pass as an import to a module. @@ -267,6 +334,50 @@ impl Func { let callable = WasmtimeFn::new(store, instance_handle, export); Func::from_wrapped(store, ty, Rc::new(callable)) } + + getters! { + /// Extracts a natively-callable object from this `Func`, if the + /// signature matches. + /// + /// See the [`Func::get1`] method for more documentation. + (get0) + + /// Extracts a natively-callable object from this `Func`, if the + /// signature matches. + /// + /// This function serves as an optimized version of the [`Func::call`] + /// method if the type signature of a function is statically known to + /// the program. This method is faster than `call` on a few metrics: + /// + /// * Runtime type-checking only happens once, when this method is + /// called. + /// * The result values, if any, aren't boxed into a vector. + /// * Arguments and return values don't go through boxing and unboxing. + /// * No trampolines are used to transfer control flow to/from JIT code, + /// instead this function jumps directly into JIT code. + /// + /// For more information about which Rust types match up to which wasm + /// types, see the documentation on [`Func::wrap1`]. + /// + /// # Return + /// + /// This function will return `None` if the type signature asserted + /// statically does not match the runtime type signature. `Some`, + /// however, will be returned if the underlying function takes one + /// parameter of type `A` and returns the parameter `R`. Currently `R` + /// can either be `()` (no return values) or one wasm type. At this time + /// a multi-value return isn't supported. + /// + /// The returned closure will always return a `Result` and an + /// `Err` is returned if a trap happens while the wasm is executing. + (get1, A) + + /// Extracts a natively-callable object from this `Func`, if the + /// signature matches. + /// + /// See the [`Func::get1`] method for more documentation. + (get2, A, B) + } } impl fmt::Debug for Func { @@ -283,66 +394,105 @@ impl fmt::Debug for Func { /// stable over time. /// /// For more information see [`Func::wrap1`] -pub trait WasmArg { +pub trait WasmTy { #[doc(hidden)] - type Abi; + type Abi: Copy; #[doc(hidden)] fn push(dst: &mut Vec); #[doc(hidden)] + fn matches(tys: impl Iterator) -> bool; + #[doc(hidden)] fn from_abi(vmctx: *mut VMContext, abi: Self::Abi) -> Self; + #[doc(hidden)] + fn into_abi(self) -> Self::Abi; } -impl WasmArg for () { +impl WasmTy for () { type Abi = (); fn push(_dst: &mut Vec) {} + fn matches(_tys: impl Iterator) -> bool { + true + } #[inline] fn from_abi(_vmctx: *mut VMContext, abi: Self::Abi) -> Self { abi } + #[inline] + fn into_abi(self) -> Self::Abi { + self + } } -impl WasmArg for i32 { +impl WasmTy for i32 { type Abi = Self; fn push(dst: &mut Vec) { dst.push(ValType::I32); } + fn matches(mut tys: impl Iterator) -> bool { + tys.next() == Some(ValType::I32) + } #[inline] fn from_abi(_vmctx: *mut VMContext, abi: Self::Abi) -> Self { abi } + #[inline] + fn into_abi(self) -> Self::Abi { + self + } } -impl WasmArg for i64 { +impl WasmTy for i64 { type Abi = Self; fn push(dst: &mut Vec) { dst.push(ValType::I64); } + fn matches(mut tys: impl Iterator) -> bool { + tys.next() == Some(ValType::I64) + } #[inline] fn from_abi(_vmctx: *mut VMContext, abi: Self::Abi) -> Self { abi } + #[inline] + fn into_abi(self) -> Self::Abi { + self + } } -impl WasmArg for f32 { +impl WasmTy for f32 { type Abi = Self; fn push(dst: &mut Vec) { dst.push(ValType::F32); } + fn matches(mut tys: impl Iterator) -> bool { + tys.next() == Some(ValType::F32) + } #[inline] fn from_abi(_vmctx: *mut VMContext, abi: Self::Abi) -> Self { abi } + #[inline] + fn into_abi(self) -> Self::Abi { + self + } } -impl WasmArg for f64 { +impl WasmTy for f64 { type Abi = Self; fn push(dst: &mut Vec) { dst.push(ValType::F64); } + fn matches(mut tys: impl Iterator) -> bool { + tys.next() == Some(ValType::F64) + } #[inline] fn from_abi(_vmctx: *mut VMContext, abi: Self::Abi) -> Self { abi } + #[inline] + fn into_abi(self) -> Self::Abi { + self + } } /// A trait implemented for types which can be returned from closures passed to @@ -359,31 +509,41 @@ pub trait WasmRet { #[doc(hidden)] fn push(dst: &mut Vec); #[doc(hidden)] + fn matches(tys: impl Iterator) -> bool; + #[doc(hidden)] fn into_abi(self) -> Self::Abi; } -impl WasmRet for T { - type Abi = T; +impl WasmRet for T { + type Abi = T::Abi; fn push(dst: &mut Vec) { T::push(dst) } + fn matches(tys: impl Iterator) -> bool { + T::matches(tys) + } + #[inline] fn into_abi(self) -> Self::Abi { - self + T::into_abi(self) } } -impl WasmRet for Result { - type Abi = T; +impl WasmRet for Result { + type Abi = T::Abi; fn push(dst: &mut Vec) { T::push(dst) } + fn matches(tys: impl Iterator) -> bool { + T::matches(tys) + } + #[inline] fn into_abi(self) -> Self::Abi { match self { - Ok(val) => return val, + Ok(val) => return val.into_abi(), Err(trap) => handle_trap(trap), } diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index 911bc18dd43c..4e5bce5b6c5e 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -24,7 +24,7 @@ mod values; pub use crate::callable::Callable; pub use crate::externals::*; pub use crate::frame_info::FrameInfo; -pub use crate::func::{Func, WasmArg, WasmRet}; +pub use crate::func::{Func, WasmRet, WasmTy}; pub use crate::instance::Instance; pub use crate::module::Module; pub use crate::r#ref::{AnyRef, HostInfo, HostRef}; diff --git a/crates/api/tests/func.rs b/crates/api/tests/func.rs index c7eb15eab8b3..7ab8494b0c9f 100644 --- a/crates/api/tests/func.rs +++ b/crates/api/tests/func.rs @@ -1,6 +1,7 @@ use anyhow::Result; +use std::rc::Rc; use std::sync::atomic::{AtomicUsize, Ordering::SeqCst}; -use wasmtime::{Func, Instance, Module, Store, Trap, ValType}; +use wasmtime::{Callable, Func, FuncType, Instance, Module, Store, Trap, Val, ValType}; #[test] fn func_constructors() { @@ -201,3 +202,90 @@ fn trap_import() -> Result<()> { assert_eq!(trap.message(), "foo"); Ok(()) } + +#[test] +fn get_from_wrapper() { + let store = Store::default(); + let f = Func::wrap0(&store, || {}); + assert!(f.get0::<()>().is_some()); + assert!(f.get0::().is_none()); + assert!(f.get1::<(), ()>().is_some()); + assert!(f.get1::().is_none()); + assert!(f.get1::().is_none()); + assert!(f.get2::<(), (), ()>().is_some()); + assert!(f.get2::().is_none()); + assert!(f.get2::().is_none()); + + let f = Func::wrap0(&store, || -> i32 { loop {} }); + assert!(f.get0::().is_some()); + let f = Func::wrap0(&store, || -> f32 { loop {} }); + assert!(f.get0::().is_some()); + let f = Func::wrap0(&store, || -> f64 { loop {} }); + assert!(f.get0::().is_some()); + + let f = Func::wrap1(&store, |_: i32| {}); + assert!(f.get1::().is_some()); + assert!(f.get1::().is_none()); + assert!(f.get1::().is_none()); + assert!(f.get1::().is_none()); + let f = Func::wrap1(&store, |_: i64| {}); + assert!(f.get1::().is_some()); + let f = Func::wrap1(&store, |_: f32| {}); + assert!(f.get1::().is_some()); + let f = Func::wrap1(&store, |_: f64| {}); + assert!(f.get1::().is_some()); +} + +#[test] +fn get_from_signature() { + struct Foo; + impl Callable for Foo { + fn call(&self, _params: &[Val], _results: &mut [Val]) -> Result<(), Trap> { + panic!() + } + } + let store = Store::default(); + let ty = FuncType::new(Box::new([]), Box::new([])); + let f = Func::new(&store, ty, Rc::new(Foo)); + assert!(f.get0::<()>().is_some()); + assert!(f.get0::().is_none()); + assert!(f.get1::().is_none()); + + let ty = FuncType::new(Box::new([ValType::I32]), Box::new([ValType::F64])); + let f = Func::new(&store, ty, Rc::new(Foo)); + assert!(f.get0::<()>().is_none()); + assert!(f.get0::().is_none()); + assert!(f.get1::().is_none()); + assert!(f.get1::().is_some()); +} + +#[test] +fn get_from_module() -> anyhow::Result<()> { + let store = Store::default(); + let module = Module::new( + &store, + r#" + (module + (func (export "f0")) + (func (export "f1") (param i32)) + (func (export "f2") (result i32) + i32.const 0) + ) + + "#, + )?; + let instance = Instance::new(&module, &[])?; + let f0 = instance.get_export("f0").unwrap().func().unwrap(); + assert!(f0.get0::<()>().is_some()); + assert!(f0.get0::().is_none()); + let f1 = instance.get_export("f1").unwrap().func().unwrap(); + assert!(f1.get0::<()>().is_none()); + assert!(f1.get1::().is_some()); + assert!(f1.get1::().is_none()); + let f2 = instance.get_export("f2").unwrap().func().unwrap(); + assert!(f2.get0::<()>().is_none()); + assert!(f2.get0::().is_some()); + assert!(f2.get1::().is_none()); + assert!(f2.get1::().is_none()); + Ok(()) +} diff --git a/crates/runtime/signalhandlers/Trampolines.cpp b/crates/runtime/signalhandlers/Trampolines.cpp index c76db87fd21b..e0702c349db3 100644 --- a/crates/runtime/signalhandlers/Trampolines.cpp +++ b/crates/runtime/signalhandlers/Trampolines.cpp @@ -3,35 +3,16 @@ #include "SignalHandlers.hpp" extern "C" -int WasmtimeCallTrampoline( +int RegisterSetjmp( void **buf_storage, - void *vmctx, - void *caller_vmctx, - void (*trampoline)(void*, void*, void*, void*), - void *body, - void *args) -{ + void (*body)(void*), + void *payload) { jmp_buf buf; if (setjmp(buf) != 0) { return 0; } *buf_storage = &buf; - trampoline(vmctx, caller_vmctx, body, args); - return 1; -} - -extern "C" -int WasmtimeCall( - void **buf_storage, - void *vmctx, - void *caller_vmctx, - void (*body)(void*, void*)) { - jmp_buf buf; - if (setjmp(buf) != 0) { - return 0; - } - *buf_storage = &buf; - body(vmctx, caller_vmctx); + body(payload); return 1; } diff --git a/crates/runtime/src/instance.rs b/crates/runtime/src/instance.rs index 0e851d3705fb..23545d1a4d5b 100644 --- a/crates/runtime/src/instance.rs +++ b/crates/runtime/src/instance.rs @@ -8,7 +8,7 @@ use crate::jit_int::GdbJitImageRegistration; use crate::memory::LinearMemory; use crate::signalhandlers; use crate::table::Table; -use crate::traphandlers::{wasmtime_call, Trap}; +use crate::traphandlers::{catch_traps, Trap}; use crate::vmcontext::{ VMBuiltinFunctionsArray, VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, VMFunctionImport, VMGlobalDefinition, VMGlobalImport, VMMemoryDefinition, VMMemoryImport, VMSharedSignatureIndex, @@ -367,8 +367,15 @@ impl Instance { }; // Make the call. - unsafe { wasmtime_call(callee_vmctx, self.vmctx_ptr(), callee_address) } + unsafe { + catch_traps(callee_vmctx, || { + mem::transmute::< + *const VMFunctionBody, + unsafe extern "C" fn(*mut VMContext, *mut VMContext), + >(callee_address)(callee_vmctx, self.vmctx_ptr()) + }) .map_err(InstantiationError::StartTrap) + } } /// Return the offset from the vmctx pointer to its containing Instance. diff --git a/crates/runtime/src/lib.rs b/crates/runtime/src/lib.rs index 188265a74736..24628c446e47 100644 --- a/crates/runtime/src/lib.rs +++ b/crates/runtime/src/lib.rs @@ -44,7 +44,9 @@ pub use crate::mmap::Mmap; pub use crate::sig_registry::SignatureRegistry; pub use crate::trap_registry::{TrapDescription, TrapRegistration, TrapRegistry}; pub use crate::traphandlers::resume_panic; -pub use crate::traphandlers::{raise_user_trap, wasmtime_call, wasmtime_call_trampoline, Trap}; +pub use crate::traphandlers::{ + catch_traps, raise_user_trap, wasmtime_call_trampoline, Trap, +}; pub use crate::vmcontext::{ VMCallerCheckedAnyfunc, VMContext, VMFunctionBody, VMFunctionImport, VMGlobalDefinition, VMGlobalImport, VMInvokeArgument, VMMemoryDefinition, VMMemoryImport, VMSharedSignatureIndex, diff --git a/crates/runtime/src/traphandlers.rs b/crates/runtime/src/traphandlers.rs index c76a37dc97cc..ba05cb37eb89 100644 --- a/crates/runtime/src/traphandlers.rs +++ b/crates/runtime/src/traphandlers.rs @@ -9,23 +9,15 @@ use std::any::Any; use std::cell::Cell; use std::error::Error; use std::fmt; +use std::mem; use std::ptr; use wasmtime_environ::ir; extern "C" { - fn WasmtimeCallTrampoline( + fn RegisterSetjmp( jmp_buf: *mut *const u8, - vmctx: *mut u8, - caller_vmctx: *mut u8, - trampoline: *const VMFunctionBody, - callee: *const VMFunctionBody, - values_vec: *mut u8, - ) -> i32; - fn WasmtimeCall( - jmp_buf: *mut *const u8, - vmctx: *mut u8, - caller_vmctx: *mut u8, - callee: *const VMFunctionBody, + callback: extern "C" fn(*mut u8), + payload: *mut u8, ) -> i32; fn Unwind(jmp_buf: *const u8) -> !; } @@ -154,33 +146,36 @@ pub unsafe fn wasmtime_call_trampoline( callee: *const VMFunctionBody, values_vec: *mut u8, ) -> Result<(), Trap> { - CallThreadState::new(vmctx).with(|cx| { - WasmtimeCallTrampoline( - cx.jmp_buf.as_ptr(), - vmctx as *mut u8, - caller_vmctx as *mut u8, - trampoline, - callee, - values_vec, - ) + catch_traps(vmctx, || { + mem::transmute::< + _, + extern "C" fn(*mut VMContext, *mut VMContext, *const VMFunctionBody, *mut u8), + >(trampoline)(vmctx, caller_vmctx, callee, values_vec) }) } -/// Call the wasm function pointed to by `callee`, which has no arguments or -/// return values. -pub unsafe fn wasmtime_call( - vmctx: *mut VMContext, - caller_vmctx: *mut VMContext, - callee: *const VMFunctionBody, -) -> Result<(), Trap> { - CallThreadState::new(vmctx).with(|cx| { - WasmtimeCall( +/// Catches any wasm traps that happen within the execution of `closure`, +/// returning them as a `Result`. +/// +/// Highly unsafe since `closure` won't have any dtors run. +pub unsafe fn catch_traps(vmctx: *mut VMContext, mut closure: F) -> Result<(), Trap> +where + F: FnMut(), +{ + return CallThreadState::new(vmctx).with(|cx| { + RegisterSetjmp( cx.jmp_buf.as_ptr(), - vmctx as *mut u8, - caller_vmctx as *mut u8, - callee, + call_closure::, + &mut closure as *mut F as *mut u8, ) - }) + }); + + extern "C" fn call_closure(payload: *mut u8) + where + F: FnMut(), + { + unsafe { (*(payload as *mut F))() } + } } /// Temporary state stored on the stack which is registered in the `tls` module diff --git a/crates/wasi/src/lib.rs b/crates/wasi/src/lib.rs index d09ee7c1f8d3..5f0b24b1272d 100644 --- a/crates/wasi/src/lib.rs +++ b/crates/wasi/src/lib.rs @@ -20,7 +20,7 @@ pub fn is_wasi_module(name: &str) -> bool { /// This is an internal structure used to acquire a handle on the caller's /// wasm memory buffer. /// -/// This exploits how we can implement `WasmArg` for ourselves locally even +/// This exploits how we can implement `WasmTy` for ourselves locally even /// though crates in general should not be doing that. This is a crate in /// the wasmtime project, however, so we should be able to keep up with our own /// changes. @@ -33,11 +33,15 @@ struct WasiCallerMemory { len: usize, } -impl wasmtime::WasmArg for WasiCallerMemory { +impl wasmtime::WasmTy for WasiCallerMemory { type Abi = (); fn push(_dst: &mut Vec) {} + fn matches(_tys: impl Iterator) -> bool { + true + } + fn from_abi(vmctx: *mut wasmtime_runtime::VMContext, _abi: ()) -> Self { unsafe { match wasmtime_runtime::InstanceHandle::from_vmctx(vmctx).lookup("memory") { @@ -56,6 +60,8 @@ impl wasmtime::WasmArg for WasiCallerMemory { } } } + + fn into_abi(self) {} } impl WasiCallerMemory {