Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mpk: restore PKRU state when a fiber resumes execution #7789

Merged
merged 3 commits into from
Feb 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions crates/runtime/src/mpk/disabled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ pub fn keys(_: usize) -> &'static [ProtectionKey] {
}
pub fn allow(_: ProtectionMask) {}

pub fn current_mask() -> ProtectionMask {
ProtectionMask
}

#[derive(Clone, Copy, Debug)]
pub enum ProtectionKey {}
impl ProtectionKey {
Expand Down
5 changes: 5 additions & 0 deletions crates/runtime/src/mpk/enabled.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ pub fn allow(mask: ProtectionMask) {
log::trace!("PKRU change: {:#034b} => {:#034b}", previous, pkru::read());
}

/// Retrieve the current protection mask.
pub fn current_mask() -> ProtectionMask {
ProtectionMask(pkru::read())
}

/// An MPK protection key.
///
/// The expected usage is:
Expand Down
4 changes: 2 additions & 2 deletions crates/runtime/src/mpk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ cfg_if::cfg_if! {
mod enabled;
mod pkru;
mod sys;
pub use enabled::{allow, is_supported, keys, ProtectionKey, ProtectionMask};
pub use enabled::{allow, current_mask, is_supported, keys, ProtectionKey, ProtectionMask};
} else {
mod disabled;
pub use disabled::{allow, is_supported, keys, ProtectionKey, ProtectionMask};
pub use disabled::{allow, current_mask, is_supported, keys, ProtectionKey, ProtectionMask};
}
}

Expand Down
23 changes: 20 additions & 3 deletions crates/wasmtime/src/runtime/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,11 @@ use std::ptr;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::task::{Context, Poll};
use wasmtime_runtime::mpk::{self, ProtectionKey, ProtectionMask};
use wasmtime_runtime::{
mpk::ProtectionKey, ExportGlobal, InstanceAllocationRequest, InstanceAllocator, InstanceHandle,
ModuleInfo, OnDemandInstanceAllocator, SignalHandler, StoreBox, StorePtr, VMContext,
VMExternRef, VMExternRefActivationsTable, VMFuncRef, VMRuntimeLimits, WasmFault,
ExportGlobal, InstanceAllocationRequest, InstanceAllocator, InstanceHandle, ModuleInfo,
OnDemandInstanceAllocator, SignalHandler, StoreBox, StorePtr, VMContext, VMExternRef,
VMExternRefActivationsTable, VMFuncRef, VMRuntimeLimits, WasmFault,
};

mod context;
Expand Down Expand Up @@ -1401,6 +1402,7 @@ impl StoreOpaque {
Some(AsyncCx {
current_suspend: self.async_state.current_suspend.get(),
current_poll_cx: poll_cx_box_ptr,
track_pkey_context_switch: mpk::is_supported() && self.pkey.is_some(),
})
}

Expand Down Expand Up @@ -1938,6 +1940,7 @@ impl<T> StoreContextMut<'_, T> {
pub struct AsyncCx {
current_suspend: *mut *const wasmtime_fiber::Suspend<Result<()>, (), Result<()>>,
current_poll_cx: *mut *mut Context<'static>,
track_pkey_context_switch: bool,
}

#[cfg(feature = "async")]
Expand Down Expand Up @@ -1998,7 +2001,21 @@ impl AsyncCx {
Poll::Pending => {}
}

// In order to prevent this fiber's MPK state from being munged by
// other fibers while it is suspended, we save and restore it once
// once execution resumes. Note that when MPK is not supported,
// these are noops.
let previous_mask = if self.track_pkey_context_switch {
let previous_mask = mpk::current_mask();
mpk::allow(ProtectionMask::all());
previous_mask
} else {
ProtectionMask::all()
};
(*suspend).suspend(())?;
if self.track_pkey_context_switch {
mpk::allow(previous_mask);
}
}
}
}
Expand Down
78 changes: 78 additions & 0 deletions tests/all/async_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,84 @@ async fn async_host_func_with_pooling_stacks() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn async_mpk_protection() -> Result<()> {
let _ = env_logger::try_init();

// Construct a pool with MPK protection enabled; note that the MPK
// protection is configured in `small_pool_config`.
let mut pooling = crate::small_pool_config();
pooling
.total_memories(10)
.total_stacks(2)
.memory_pages(1)
.table_elements(0);
let mut config = Config::new();
config.async_support(true);
config.allocation_strategy(InstanceAllocationStrategy::Pooling(pooling));
config.static_memory_maximum_size(1 << 26);
config.epoch_interruption(true);
let engine = Engine::new(&config)?;

// Craft a module that loops for several iterations and checks whether it
// has access to its memory range (0x0-0x10000).
const WAT: &str = "
(module
(func $start
(local $i i32)
(local.set $i (i32.const 3))
(loop $cont
(drop (i32.load (i32.const 0)))
(drop (i32.load (i32.const 0xfffc)))
(br_if $cont (local.tee $i (i32.sub (local.get $i) (i32.const 1))))))
(memory 1)
(start $start))
";

// Start two instances of the module in separate fibers, `a` and `b`.
async fn run_instance(engine: &Engine, name: &str) -> Instance {
let mut store = Store::new(&engine, ());
store.set_epoch_deadline(0);
store.epoch_deadline_async_yield_and_update(0);
let module = Module::new(store.engine(), WAT).unwrap();
println!("[{name}] building instance");
Instance::new_async(&mut store, &module, &[]).await.unwrap()
}
let mut a = Box::pin(run_instance(&engine, "a"));
let mut b = Box::pin(run_instance(&engine, "b"));

// Alternately poll each instance until completion. This should exercise
// fiber suspensions requiring the `Store` to appropriately save and restore
// the PKRU context between suspensions (see `AsyncCx::block_on`).
for i in 0..10 {
if i % 2 == 0 {
match PollOnce::new(a).await {
Ok(_) => {
println!("[a] done");
break;
}
Err(a_) => {
println!("[a] not done");
a = a_;
}
}
} else {
match PollOnce::new(b).await {
Ok(_) => {
println!("[b] done");
break;
}
Err(b_) => {
println!("[b] not done");
b = b_;
}
}
}
}

Ok(())
}

/// This will execute the `future` provided to completion and each invocation of
/// `poll` for the future will be executed on a separate thread.
pub async fn execute_across_threads<F>(future: F) -> F::Output
Expand Down
Loading