diff --git a/crates/runtime/src/mpk/disabled.rs b/crates/runtime/src/mpk/disabled.rs index 7f6de55bc3ea..fbfd95e088ab 100644 --- a/crates/runtime/src/mpk/disabled.rs +++ b/crates/runtime/src/mpk/disabled.rs @@ -13,6 +13,10 @@ pub fn keys(_: usize) -> &'static [ProtectionKey] { } pub fn allow(_: ProtectionMask) {} +pub fn current_mask() -> ProtectionMask { + ProtectionMask +} + #[derive(Clone, Copy, Debug)] pub enum ProtectionKey {} impl ProtectionKey { diff --git a/crates/runtime/src/mpk/enabled.rs b/crates/runtime/src/mpk/enabled.rs index 2ce2de8ea3b4..23b7fa85f7ea 100644 --- a/crates/runtime/src/mpk/enabled.rs +++ b/crates/runtime/src/mpk/enabled.rs @@ -55,6 +55,11 @@ pub fn allow(mask: ProtectionMask) { log::trace!("PKRU change: {:#034b} => {:#034b}", previous, pkru::read()); } +/// Retrieve the current protection mask. +pub fn current_mask() -> ProtectionMask { + ProtectionMask(pkru::read()) +} + /// An MPK protection key. /// /// The expected usage is: diff --git a/crates/runtime/src/mpk/mod.rs b/crates/runtime/src/mpk/mod.rs index f60b3880b09b..55f822aa5789 100644 --- a/crates/runtime/src/mpk/mod.rs +++ b/crates/runtime/src/mpk/mod.rs @@ -34,10 +34,10 @@ cfg_if::cfg_if! { mod enabled; mod pkru; mod sys; - pub use enabled::{allow, is_supported, keys, ProtectionKey, ProtectionMask}; + pub use enabled::{allow, current_mask, is_supported, keys, ProtectionKey, ProtectionMask}; } else { mod disabled; - pub use disabled::{allow, is_supported, keys, ProtectionKey, ProtectionMask}; + pub use disabled::{allow, current_mask, is_supported, keys, ProtectionKey, ProtectionMask}; } } diff --git a/crates/wasmtime/src/store.rs b/crates/wasmtime/src/store.rs index 5a5c1e94e9cd..9c26b004f221 100644 --- a/crates/wasmtime/src/store.rs +++ b/crates/wasmtime/src/store.rs @@ -95,10 +95,11 @@ use std::ptr; use std::sync::atomic::AtomicU64; use std::sync::Arc; use std::task::{Context, Poll}; +use wasmtime_runtime::mpk::{self, ProtectionKey, ProtectionMask}; use wasmtime_runtime::{ - mpk::ProtectionKey, ExportGlobal, InstanceAllocationRequest, InstanceAllocator, InstanceHandle, - ModuleInfo, OnDemandInstanceAllocator, SignalHandler, StoreBox, StorePtr, VMContext, - VMExternRef, VMExternRefActivationsTable, VMFuncRef, VMRuntimeLimits, WasmFault, + ExportGlobal, InstanceAllocationRequest, InstanceAllocator, InstanceHandle, ModuleInfo, + OnDemandInstanceAllocator, SignalHandler, StoreBox, StorePtr, VMContext, VMExternRef, + VMExternRefActivationsTable, VMFuncRef, VMRuntimeLimits, WasmFault, }; mod context; @@ -1977,7 +1978,14 @@ impl AsyncCx { Poll::Pending => {} } + // In order to prevent this fiber's MPK state from being munged by + // other fibers while it is suspended, we save and restore it once + // once execution resumes. Note that when MPK is not supported, + // these are noops. + let previous_mask = mpk::current_mask(); + mpk::allow(ProtectionMask::all()); (*suspend).suspend(())?; + mpk::allow(previous_mask); } } }