diff --git a/tests/test_trap.py b/tests/test_trap.py index b7e3bfbe..221489b7 100644 --- a/tests/test_trap.py +++ b/tests/test_trap.py @@ -78,3 +78,32 @@ def test_frames_no_module(self): self.assertEqual(frames[0].func_index, 0) self.assertEqual(frames[0].func_name, None) self.assertEqual(frames[0].module_name, None) + + def test_wasi_exit(self): + store = Store() + module = Module(store, """ + (module + (import "wasi_snapshot_preview1" "proc_exit" (func $exit (param i32))) + (func (export "exit") (param i32) + local.get 0 + call $exit) + ) + """) + linker = Linker(store) + wasi = WasiConfig() + linker.define_wasi(WasiInstance(store, "wasi_snapshot_preview1", wasi)) + instance = linker.instantiate(module) + exit = instance.exports["exit"] + assert(isinstance(exit, Func)) + + try: + exit(0) + assert(False) + except ExitTrap as e: + assert(e.code == 0) + + try: + exit(1) + assert(False) + except ExitTrap as e: + assert(e.code == 1) diff --git a/wasmtime/__init__.py b/wasmtime/__init__.py index 57484fbf..cf6d9f48 100644 --- a/wasmtime/__init__.py +++ b/wasmtime/__init__.py @@ -22,7 +22,7 @@ from ._wat2wasm import wat2wasm from ._module import Module from ._value import Val, IntoVal -from ._trap import Trap, Frame +from ._trap import Trap, Frame, ExitTrap from ._func import Func, Caller from ._globals import Global from ._table import Table @@ -52,6 +52,7 @@ 'Memory', 'Global', 'Trap', + 'ExitTrap', 'Frame', 'Module', 'Instance', diff --git a/wasmtime/_bindings.py b/wasmtime/_bindings.py index 410c33f3..a512f948 100644 --- a/wasmtime/_bindings.py +++ b/wasmtime/_bindings.py @@ -1972,6 +1972,12 @@ def wasmtime_interrupt_handle_new(store: Any) -> pointer: def wasmtime_interrupt_handle_interrupt(handle: Any) -> None: return _wasmtime_interrupt_handle_interrupt(handle) # type: ignore +_wasmtime_trap_exit_status = dll.wasmtime_trap_exit_status +_wasmtime_trap_exit_status.restype = c_bool +_wasmtime_trap_exit_status.argtypes = [POINTER(wasm_trap_t), POINTER(c_int)] +def wasmtime_trap_exit_status(arg0: Any, status: Any) -> c_bool: + return _wasmtime_trap_exit_status(arg0, status) # type: ignore + _wasmtime_frame_func_name = dll.wasmtime_frame_func_name _wasmtime_frame_func_name.restype = POINTER(wasm_name_t) _wasmtime_frame_func_name.argtypes = [POINTER(wasm_frame_t)] diff --git a/wasmtime/_trap.py b/wasmtime/_trap.py index 742a843a..0ce22ea8 100644 --- a/wasmtime/_trap.py +++ b/wasmtime/_trap.py @@ -1,5 +1,5 @@ from . import _ffi as ffi -from ctypes import byref, POINTER, pointer +from ctypes import byref, POINTER, pointer, c_int from wasmtime import Store, WasmtimeError from typing import Optional, Any, List @@ -24,9 +24,16 @@ def __init__(self, store: Store, message: str): def __from_ptr__(cls, ptr: "pointer[ffi.wasm_trap_t]") -> "Trap": if not isinstance(ptr, POINTER(ffi.wasm_trap_t)): raise TypeError("wrong pointer type") - trap: "Trap" = cls.__new__(cls) - trap.__ptr__ = ptr - return trap + exit_code = c_int(0) + if ffi.wasmtime_trap_exit_status(ptr, byref(exit_code)): + exit_trap: ExitTrap = ExitTrap.__new__(ExitTrap) + exit_trap.__ptr__ = ptr + exit_trap.code = exit_code.value + return exit_trap + else: + trap: Trap = cls.__new__(cls) + trap.__ptr__ = ptr + return trap @property def message(self) -> str: @@ -60,6 +67,23 @@ def __del__(self) -> None: ffi.wasm_trap_delete(self.__ptr__) +class ExitTrap(Trap): + """ + A special type of `Trap` which represents the process exiting via WASI's + `proc_exit` function call. + + Whenever a WASI program exits via `proc_exit` a trap is raised, but the + trap will have this type instead of `Trap`, so you can catch just this + type instead of all traps (if desired). Exit traps have a `code` associated + with them which is the exit code provided at exit. + + Note that `ExitTrap` is a subclass of `Trap`, so if you catch a trap you'll + also catch `ExitTrap`. + """ + code: int + pass + + class Frame: __ptr__: "pointer[ffi.wasm_frame_t]" __owner__: Optional[Any]