Skip to content

Commit

Permalink
Save the context of multiple processes
Browse files Browse the repository at this point in the history
When pausing processes, we now save the user context onto the user stack instead of into GsData.

This means we can hold the context of multiple processes, by having GsData hold mere references to the user stack pointer for each process.

Change-Id: I84f5dbfde4c72b703d1133d2b1a3dfb3a804dbf5
  • Loading branch information
jul-sh committed Jul 1, 2024
1 parent 625064e commit 5664706
Showing 1 changed file with 137 additions and 130 deletions.
267 changes: 137 additions & 130 deletions oak_restricted_kernel/src/syscall/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ mod switch_process;
mod tests;

use alloc::boxed::Box;
use core::{arch::asm, ffi::c_void, mem::offset_of};
use core::{arch::asm, ffi::c_void, mem::offset_of, ptr::addr_of_mut};

use oak_channel::Channel;
use oak_restricted_kernel_interface::{Errno, Syscall};
Expand All @@ -50,57 +50,13 @@ use self::{
};
use crate::mm;

/// State we need to track for system calls.
///
/// `enable_syscalls` writes the address of an instance of this structure to
/// both `KernelGsBase` and `GsBase`; the syscall entry assembly then reaches
/// its fields through the `gs` segment using `offset_of!`-derived offsets.
#[repr(C, align(16))]
#[derive(Debug)]
struct GsData {
/// Kernel stack pointer (what to set in RSP after saving user RSP).
kernel_sp: VirtAddr,

/// User context data: the register values the entry assembly stores on
/// syscall entry and reloads before `sysretq`. There is a single slot, so
/// only one user context can be held at a time.
user_ctx: UserContext,
}

/// Saved user register state, embedded in `GsData` as `user_ctx`.
///
/// The syscall entry assembly addresses each field as
/// `gs:[offset_of!(GsData, user_ctx.<field>)]`.
#[repr(C, align(16))]
#[derive(Debug, Default)]
struct UserContext {
/// User stack pointer, stored before switching to the kernel stack and
/// reloaded just before returning to user code.
rsp: u64,
rcx: u64,
rdi: u64,
rsi: u64,
rdx: u64,
rbx: u64,
r8: u64,
r9: u64,
r10: u64,
/// Holds the user RFLAGS value across the syscall.
r11: u64,
r12: u64,
r13: u64,
r14: u64,
r15: u64,
/// NOTE(review): no store to or load from this field is visible in the
/// entry assembly shown here; confirm whether it is used.
rbp: u64,
/// YMM0 through YMM15, one 32-byte row (4 x `u64`) per register; the
/// assembly accesses row `i` at `OFFSET_YMM + i*32` with `vmovups`.
ymm: [[u64; 4]; 16],
}

pub fn enable_syscalls(channel: Box<dyn Channel>, dice_data: dice_data::DiceData) {
channel::register(channel);
stdio::register();
key::register();
dice_data::register(dice_data);

// Allocate a stack for the system call handler.
let kernel_sp = mm::allocate_stack();

// Store the gsdata structure in the kernel heap.
// We need the GsData to stick around statically, so we'll leak it here.
let gsdata = Box::leak(Box::new(GsData {
// Stack grows down, so SP points to the end of the page
kernel_sp,
user_ctx: UserContext::default(),
}));

KernelGsBase::write(VirtAddr::from_ptr(gsdata));
GsBase::write(VirtAddr::from_ptr(gsdata));
GsData::setup();

LStar::write(VirtAddr::new(syscall_entrypoint as usize as u64));
unsafe {
Expand Down Expand Up @@ -138,6 +94,46 @@ extern "sysv64" fn syscall_handler(
}
}

/// Number of entries in `GsData::user_stack_pointers` (one slot per PID).
const MAX_PROCESSES: usize = 16;

/// The single `GsData` instance used for system calls.
///
/// Every field starts out zeroed; `GsData::setup` fills in `kernel_sp` and
/// writes this static's address to `KernelGsBase` and `GsBase`.
static mut GS_DATA: GsData = GsData {
kernel_sp: VirtAddr::zero(),
user_stack_pointers: [VirtAddr::zero(); MAX_PROCESSES],
current_pid: 0,
};

/// State we need to track for system calls.
///
/// The syscall entry assembly reaches these fields through the `gs` segment
/// using `offset_of!`-derived offsets. Register contents are not stored here:
/// the assembly pushes them onto the user stack and records only the
/// resulting stack pointer in `user_stack_pointers`.
#[repr(C, align(16))]
#[derive(Debug)]
struct GsData {
/// Kernel stack pointer (what to set in RSP after saving user RSP).
kernel_sp: VirtAddr,

/// Current process ID.
///
/// Used by the entry assembly as the index into `user_stack_pointers`. It
/// is read once before the handler runs (to choose the slot to save into)
/// and again afterwards (to choose the slot to restore from), since the
/// value may have changed in between.
///
/// NOTE(review): the assembly shown does not range-check this value
/// against `MAX_PROCESSES` before indexing; confirm that writers of this
/// field enforce the bound.
current_pid: usize,

/// Array of user stack pointers (RSP) for each process, by PID.
///
/// Each entry is the user RSP taken after the general-purpose registers
/// and the 512-byte YMM save area have been placed on that user stack.
user_stack_pointers: [VirtAddr; MAX_PROCESSES],
}

impl GsData {
    /// Initializes the global `GS_DATA` and points both GS base registers at it.
    ///
    /// Allocates the stack that the system call handler runs on, records it in
    /// `GS_DATA.kernel_sp`, and then writes the address of `GS_DATA` to both
    /// `KernelGsBase` and `GsBase`.
    pub fn setup() {
        // Stack for the system call handler.
        let syscall_stack = mm::allocate_stack();

        // `GS_DATA` is a static, so its address stays valid for as long as the
        // GS base registers refer to it.
        //
        // SAFETY: This is called during initialization, so we know no other
        // threads are accessing GS_DATA.
        unsafe {
            GS_DATA.kernel_sp = syscall_stack;

            let gs_data_addr = VirtAddr::from_ptr(addr_of_mut!(GS_DATA));
            KernelGsBase::write(gs_data_addr);
            GsBase::write(gs_data_addr);
        }
    }
}

/// Main entry point for system calls in the Oak Restricted Kernel.
///
/// As we only support x86-64, we rely on the `SYSCALL`/`SYSRET` mechanism to
Expand Down Expand Up @@ -177,44 +173,57 @@ extern "C" fn syscall_entrypoint() {
// for more details.
unsafe {
asm! {
// Switch to the syscall stack
"swapgs", // switch to kernel GS
// Switch to kernel GS.
"swapgs",

// Save user context to GsData
"mov gs:[{OFFSET_RSP}], rsp",
"mov gs:[{OFFSET_RCX}], rcx",
"mov gs:[{OFFSET_RDI}], rdi",
"mov gs:[{OFFSET_RSI}], rsi",
"mov gs:[{OFFSET_RDX}], rdx",
"mov gs:[{OFFSET_RBX}], rbx",
"mov gs:[{OFFSET_R8}], r8",
"mov gs:[{OFFSET_R9}], r9",
"mov gs:[{OFFSET_R10}], r10",
"mov gs:[{OFFSET_R11}], r11",
"mov gs:[{OFFSET_R12}], r12",
"mov gs:[{OFFSET_R13}], r13",
"mov gs:[{OFFSET_R14}], r14",
"mov gs:[{OFFSET_R15}], r15",
// Save user general-purpose registers to user stack.
// rax is not saved as it holds the syscall identifier.
"push rbx",
"push rcx",
"push rdx",
"push rsi",
"push rdi",
"push rbp",
"push r8",
"push r9",
"push r10",
"push r11", // the syscall instruction saved user RFLAGS into r11
"push r12",
"push r13",
"push r14",
"push r15",

// Save AVX registers to GsData
"vmovups gs:[{OFFSET_YMM} + 0*32], YMM0",
"vmovups gs:[{OFFSET_YMM} + 1*32], YMM1",
"vmovups gs:[{OFFSET_YMM} + 2*32], YMM2",
"vmovups gs:[{OFFSET_YMM} + 3*32], YMM3",
"vmovups gs:[{OFFSET_YMM} + 4*32], YMM4",
"vmovups gs:[{OFFSET_YMM} + 5*32], YMM5",
"vmovups gs:[{OFFSET_YMM} + 6*32], YMM6",
"vmovups gs:[{OFFSET_YMM} + 7*32], YMM7",
"vmovups gs:[{OFFSET_YMM} + 8*32], YMM8",
"vmovups gs:[{OFFSET_YMM} + 9*32], YMM9",
"vmovups gs:[{OFFSET_YMM} + 10*32], YMM10",
"vmovups gs:[{OFFSET_YMM} + 11*32], YMM11",
"vmovups gs:[{OFFSET_YMM} + 12*32], YMM12",
"vmovups gs:[{OFFSET_YMM} + 13*32], YMM13",
"vmovups gs:[{OFFSET_YMM} + 14*32], YMM14",
"vmovups gs:[{OFFSET_YMM} + 15*32], YMM15",
// Save user AVX registers to user stack.
// AVX registers (e.g., ymm0) are not saved using 'push' because they are 256 bits wide
// (or 512 bits in AVX-512) and 'push' is designed for pushing 16-bit, 32-bit, or 64-bit values.
// Additionally, AVX instructions often prefer 32-byte aligned memory access, which may not
// be guaranteed when using 'push' to place values on the stack.
"sub rsp, 512",
"vmovups [rsp + 0*32], ymm0",
"vmovups [rsp + 1*32], ymm1",
"vmovups [rsp + 2*32], ymm2",
"vmovups [rsp + 3*32], ymm3",
"vmovups [rsp + 4*32], ymm4",
"vmovups [rsp + 5*32], ymm5",
"vmovups [rsp + 6*32], ymm6",
"vmovups [rsp + 7*32], ymm7",
"vmovups [rsp + 8*32], ymm8",
"vmovups [rsp + 9*32], ymm9",
"vmovups [rsp + 10*32], ymm10",
"vmovups [rsp + 11*32], ymm11",
"vmovups [rsp + 12*32], ymm12",
"vmovups [rsp + 13*32], ymm13",
"vmovups [rsp + 14*32], ymm14",
"vmovups [rsp + 15*32], ymm15",

"mov rsp, gs:[{OFFSET_KERNEL_STACK_POINTER}]", // switch to kernel stack
// Save user stack pointer to GsData.
"mov r15, gs:[{OFFSET_CURRENT_PID}]",
"shl r15, {POINTER_SIZE_SHIFT}", // Multiply by size of VirtAddr
"add r15, {OFFSET_USER_STACK_POINTERS}",
"mov gs:[r15], rsp",

// Switch to kernel stack.
"mov rsp, gs:[{OFFSET_KERNEL_STACK_POINTER}]",

// Shuffle around register values to match sysv calling convention, and escape into
// proper Rust code from the assembly.
Expand All @@ -230,62 +239,60 @@ extern "C" fn syscall_entrypoint() {
"pop r9",
"add rsp, 8",

// Restore AVX registers from GsData.
"vmovups YMM0, gs:[{OFFSET_YMM} + 0*32]",
"vmovups YMM1, gs:[{OFFSET_YMM} + 1*32]",
"vmovups YMM2, gs:[{OFFSET_YMM} + 2*32]",
"vmovups YMM3, gs:[{OFFSET_YMM} + 3*32]",
"vmovups YMM4, gs:[{OFFSET_YMM} + 4*32]",
"vmovups YMM5, gs:[{OFFSET_YMM} + 5*32]",
"vmovups YMM6, gs:[{OFFSET_YMM} + 6*32]",
"vmovups YMM7, gs:[{OFFSET_YMM} + 7*32]",
"vmovups YMM8, gs:[{OFFSET_YMM} + 8*32]",
"vmovups YMM9, gs:[{OFFSET_YMM} + 9*32]",
"vmovups YMM10, gs:[{OFFSET_YMM} + 10*32]",
"vmovups YMM11, gs:[{OFFSET_YMM} + 11*32]",
"vmovups YMM12, gs:[{OFFSET_YMM} + 12*32]",
"vmovups YMM13, gs:[{OFFSET_YMM} + 13*32]",
"vmovups YMM14, gs:[{OFFSET_YMM} + 14*32]",
"vmovups YMM15, gs:[{OFFSET_YMM} + 15*32]",
// Re-calculate offset of the user stack pointer, the current pid may have changed.
"mov r15, gs:[{OFFSET_CURRENT_PID}]",
"shl r15, {POINTER_SIZE_SHIFT}", // Multiply by size of VirtAddr
"add r15, {OFFSET_USER_STACK_POINTERS}",

// Restore user stack pointer from GsData.
"mov rsp, gs:[r15]",

// Restore scratch registers from GsData
"mov rcx, gs:[{OFFSET_RCX}]",
"mov rdi, gs:[{OFFSET_RDI}]",
"mov rsi, gs:[{OFFSET_RSI}]",
"mov rdx, gs:[{OFFSET_RDX}]",
"mov rbx, gs:[{OFFSET_RBX}]",
"mov r8, gs:[{OFFSET_R8}]",
"mov r9, gs:[{OFFSET_R9}]",
"mov r10, gs:[{OFFSET_R10}]",
"mov r11, gs:[{OFFSET_R11}]", // restore user RFLAGS
"mov r12, gs:[{OFFSET_R12}]",
"mov r13, gs:[{OFFSET_R13}]",
"mov r14, gs:[{OFFSET_R14}]",
"mov r15, gs:[{OFFSET_R15}]",
// Restore user AVX registers from user stack.
"vmovups ymm15, [rsp + 15*32]",
"vmovups ymm14, [rsp + 14*32]",
"vmovups ymm13, [rsp + 13*32]",
"vmovups ymm12, [rsp + 12*32]",
"vmovups ymm11, [rsp + 11*32]",
"vmovups ymm10, [rsp + 10*32]",
"vmovups ymm9, [rsp + 9*32]",
"vmovups ymm8, [rsp + 8*32]",
"vmovups ymm7, [rsp + 7*32]",
"vmovups ymm6, [rsp + 6*32]",
"vmovups ymm5, [rsp + 5*32]",
"vmovups ymm4, [rsp + 4*32]",
"vmovups ymm3, [rsp + 3*32]",
"vmovups ymm2, [rsp + 2*32]",
"vmovups ymm1, [rsp + 1*32]",
"vmovups ymm0, [rsp + 0*32]",
"add rsp, 512",

// Restore user general-purpose registers from user stack.
"pop r15",
"pop r14",
"pop r13",
"pop r12",
"pop r11", // the sysret instruction will copy r11 into RFLAGS
"pop r10",
"pop r9",
"pop r8",
"pop rbp",
"pop rdi",
"pop rsi",
"pop rdx",
"pop rcx",
"pop rbx",
// rax is not restored as it holds the syscall return value.

// Restore user RSP in preparation for SYSRET.
"mov rsp, gs:[{OFFSET_RSP}]",
"swapgs", // restore user GS
// Restore user GS.
"swapgs",

// Back to user code in Ring 3.
"sysretq",
HANDLER = sym syscall_handler,
OFFSET_KERNEL_STACK_POINTER = const(offset_of!(GsData, kernel_sp)),
OFFSET_RSP = const(offset_of!(GsData, user_ctx.rsp)),
OFFSET_RCX = const(offset_of!(GsData, user_ctx.rcx)),
OFFSET_RDI = const(offset_of!(GsData, user_ctx.rdi)),
OFFSET_RSI = const(offset_of!(GsData, user_ctx.rsi)),
OFFSET_RDX = const(offset_of!(GsData, user_ctx.rdx)),
OFFSET_RBX = const(offset_of!(GsData, user_ctx.rbx)),
OFFSET_R8 = const(offset_of!(GsData, user_ctx.r8)),
OFFSET_R9 = const(offset_of!(GsData, user_ctx.r9)),
OFFSET_R10 = const(offset_of!(GsData, user_ctx.r10)),
OFFSET_R11 = const(offset_of!(GsData, user_ctx.r11)),
OFFSET_R12 = const(offset_of!(GsData, user_ctx.r12)),
OFFSET_R13 = const(offset_of!(GsData, user_ctx.r13)),
OFFSET_R14 = const(offset_of!(GsData, user_ctx.r14)),
OFFSET_R15 = const(offset_of!(GsData, user_ctx.r15)),
OFFSET_YMM = const(offset_of!(GsData, user_ctx.ymm)),
OFFSET_USER_STACK_POINTERS = const(offset_of!(GsData, user_stack_pointers)),
OFFSET_CURRENT_PID = const(offset_of!(GsData, current_pid)),
POINTER_SIZE_SHIFT = const(core::mem::size_of::<VirtAddr>().trailing_zeros()),
options(noreturn)
}
}
Expand Down

0 comments on commit 5664706

Please sign in to comment.