diff --git a/Cargo.toml b/Cargo.toml index 891acf4dc..e8292cb4a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,14 +29,16 @@ edition = "2018" bit_field = "0.10.1" bitflags = "1.3.2" volatile = "0.4.4" +rustversion = "1.0.5" [features] default = [ "nightly", "instructions" ] instructions = [] -nightly = [ "const_fn", "abi_x86_interrupt", "asm_const", "doc_cfg" ] +nightly = [ "const_fn", "step_trait", "abi_x86_interrupt", "asm_const", "doc_cfg" ] abi_x86_interrupt = [] const_fn = [] asm_const = [] +step_trait = [] doc_cfg = [] [package.metadata.release] diff --git a/src/addr.rs b/src/addr.rs index 74f677453..09ec65328 100644 --- a/src/addr.rs +++ b/src/addr.rs @@ -1,12 +1,19 @@ //! Physical and virtual addresses manipulation +#[cfg(feature = "step_trait")] +use core::convert::TryFrom; use core::fmt; +#[cfg(feature = "step_trait")] +use core::iter::Step; use core::ops::{Add, AddAssign, Sub, SubAssign}; use crate::structures::paging::page_table::PageTableLevel; use crate::structures::paging::{PageOffset, PageTableIndex}; use bit_field::BitField; +#[cfg(feature = "step_trait")] +const ADDRESS_SPACE_SIZE: u64 = 0x1_0000_0000_0000; + /// A canonical 64-bit virtual memory address. /// /// This is a wrapper type around an `u64`, so it is always 8 bytes, even when compiled @@ -40,9 +47,18 @@ pub struct PhysAddr(u64); /// a valid sign extension and are not null either. So automatic sign extension would have /// overwritten possibly meaningful bits. This likely indicates a bug, for example an invalid /// address calculation. -#[derive(Debug)] +/// +/// Contains the invalid address. pub struct VirtAddrNotValid(pub u64); +impl core::fmt::Debug for VirtAddrNotValid { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("VirtAddrNotValid") + .field(&format_args!("{:#x}", self.0)) + .finish() + } +} + impl VirtAddr { /// Creates a new canonical virtual address. /// @@ -70,7 +86,7 @@ impl VirtAddr { match addr.get_bits(47..64) { 0 | 0x1ffff => Ok(VirtAddr(addr)), // address is canonical 1 => Ok(VirtAddr::new_truncate(addr)), // address needs sign extension - other => Err(VirtAddrNotValid(other)), + _ => Err(VirtAddrNotValid(addr)), } } @@ -322,12 +338,81 @@ impl Sub for VirtAddr { } } +#[cfg(feature = "step_trait")] +impl Step for VirtAddr { + fn steps_between(start: &Self, end: &Self) -> Option { + let mut steps = end.0.checked_sub(start.0)?; + + // Check if we jumped the gap. + if end.0.get_bit(47) && !start.0.get_bit(47) { + steps = steps.checked_sub(0xffff_0000_0000_0000).unwrap(); + } + + usize::try_from(steps).ok() + } + + fn forward_checked(start: Self, count: usize) -> Option { + let offset = u64::try_from(count).ok()?; + if offset > ADDRESS_SPACE_SIZE { + return None; + } + + let mut addr = start.0.checked_add(offset)?; + + match addr.get_bits(47..) { + 0x1 => { + // Jump the gap by sign extending the 47th bit. + addr.set_bits(47.., 0x1ffff); + } + 0x2 => { + // Address overflow + return None; + } + _ => {} + } + + Some(Self::new(addr)) + } + + fn backward_checked(start: Self, count: usize) -> Option { + let offset = u64::try_from(count).ok()?; + if offset > ADDRESS_SPACE_SIZE { + return None; + } + + let mut addr = start.0.checked_sub(offset)?; + + match addr.get_bits(47..) { + 0x1fffe => { + // Jump the gap by sign extending the 47th bit. + addr.set_bits(47.., 0); + } + 0x1fffd => { + // Address underflow + return None; + } + _ => {} + } + + Some(Self::new(addr)) + } +} + /// A passed `u64` was not a valid physical address. /// /// This means that bits 52 to 64 were not all null. -#[derive(Debug)] +/// +/// Contains the invalid address. pub struct PhysAddrNotValid(pub u64); +impl core::fmt::Debug for PhysAddrNotValid { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("PhysAddrNotValid") + .field(&format_args!("{:#x}", self.0)) + .finish() + } +} + impl PhysAddr { /// Creates a new physical address. /// @@ -367,7 +452,7 @@ impl PhysAddr { pub fn try_new(addr: u64) -> Result { match addr.get_bits(52..64) { 0 => Ok(PhysAddr(addr)), // address is valid - other => Err(PhysAddrNotValid(other)), + _ => Err(PhysAddrNotValid(addr)), } } @@ -540,8 +625,7 @@ impl Sub for PhysAddr { /// /// Returns the greatest `x` with alignment `align` so that `x <= addr`. /// -/// Panics if the alignment is not a power of two. Without the `const_fn` -/// feature, the panic message will be "index out of bounds". +/// Panics if the alignment is not a power of two. #[inline] pub const fn align_down(addr: u64, align: u64) -> u64 { assert!(align.is_power_of_two(), "`align` must be a power of two"); @@ -552,8 +636,7 @@ pub const fn align_down(addr: u64, align: u64) -> u64 { /// /// Returns the smallest `x` with alignment `align` so that `x >= addr`. /// -/// Panics if the alignment is not a power of two. Without the `const_fn` -/// feature, the panic message will be "index out of bounds". +/// Panics if the alignment is not a power of two. #[inline] pub const fn align_up(addr: u64, align: u64) -> u64 { assert!(align.is_power_of_two(), "`align` must be a power of two"); @@ -577,6 +660,120 @@ mod tests { assert_eq!(VirtAddr::new_truncate(123 << 47), VirtAddr(0xfffff << 47)); } + #[test] + #[cfg(feature = "step_trait")] + fn virtaddr_step_forward() { + assert_eq!(Step::forward(VirtAddr(0), 0), VirtAddr(0)); + assert_eq!(Step::forward(VirtAddr(0), 1), VirtAddr(1)); + assert_eq!( + Step::forward(VirtAddr(0x7fff_ffff_ffff), 1), + VirtAddr(0xffff_8000_0000_0000) + ); + assert_eq!( + Step::forward(VirtAddr(0xffff_8000_0000_0000), 1), + VirtAddr(0xffff_8000_0000_0001) + ); + assert_eq!( + Step::forward_checked(VirtAddr(0xffff_ffff_ffff_ffff), 1), + None + ); + assert_eq!( + Step::forward(VirtAddr(0x7fff_ffff_ffff), 0x1234_5678_9abd), + VirtAddr(0xffff_9234_5678_9abc) + ); + assert_eq!( + Step::forward(VirtAddr(0x7fff_ffff_ffff), 0x8000_0000_0000), + VirtAddr(0xffff_ffff_ffff_ffff) + ); + assert_eq!( + Step::forward(VirtAddr(0x7fff_ffff_ff00), 0x8000_0000_00ff), + VirtAddr(0xffff_ffff_ffff_ffff) + ); + assert_eq!( + Step::forward_checked(VirtAddr(0x7fff_ffff_ff00), 0x8000_0000_0100), + None + ); + assert_eq!( + Step::forward_checked(VirtAddr(0x7fff_ffff_ffff), 0x8000_0000_0001), + None + ); + } + + #[test] + #[cfg(feature = "step_trait")] + fn virtaddr_step_backward() { + assert_eq!(Step::backward(VirtAddr(0), 0), VirtAddr(0)); + assert_eq!(Step::backward_checked(VirtAddr(0), 1), None); + assert_eq!(Step::backward(VirtAddr(1), 1), VirtAddr(0)); + assert_eq!( + Step::backward(VirtAddr(0xffff_8000_0000_0000), 1), + VirtAddr(0x7fff_ffff_ffff) + ); + assert_eq!( + Step::backward(VirtAddr(0xffff_8000_0000_0001), 1), + VirtAddr(0xffff_8000_0000_0000) + ); + assert_eq!( + Step::backward(VirtAddr(0xffff_9234_5678_9abc), 0x1234_5678_9abd), + VirtAddr(0x7fff_ffff_ffff) + ); + assert_eq!( + Step::backward(VirtAddr(0xffff_8000_0000_0000), 0x8000_0000_0000), + VirtAddr(0) + ); + assert_eq!( + Step::backward(VirtAddr(0xffff_8000_0000_0000), 0x7fff_ffff_ff01), + VirtAddr(0xff) + ); + assert_eq!( + Step::backward_checked(VirtAddr(0xffff_8000_0000_0000), 0x8000_0000_0001), + None + ); + } + + #[test] + #[cfg(feature = "step_trait")] + fn virtaddr_steps_between() { + assert_eq!(Step::steps_between(&VirtAddr(0), &VirtAddr(0)), Some(0)); + assert_eq!(Step::steps_between(&VirtAddr(0), &VirtAddr(1)), Some(1)); + assert_eq!(Step::steps_between(&VirtAddr(1), &VirtAddr(0)), None); + assert_eq!( + Step::steps_between( + &VirtAddr(0x7fff_ffff_ffff), + &VirtAddr(0xffff_8000_0000_0000) + ), + Some(1) + ); + assert_eq!( + Step::steps_between( + &VirtAddr(0xffff_8000_0000_0000), + &VirtAddr(0x7fff_ffff_ffff) + ), + None + ); + assert_eq!( + Step::steps_between( + &VirtAddr(0xffff_8000_0000_0000), + &VirtAddr(0xffff_8000_0000_0000) + ), + Some(0) + ); + assert_eq!( + Step::steps_between( + &VirtAddr(0xffff_8000_0000_0000), + &VirtAddr(0xffff_8000_0000_0001) + ), + Some(1) + ); + assert_eq!( + Step::steps_between( + &VirtAddr(0xffff_8000_0000_0001), + &VirtAddr(0xffff_8000_0000_0000) + ), + None + ); + } + #[test] pub fn test_align_up() { // align 1 diff --git a/src/lib.rs b/src/lib.rs index 2f14df068..6c32eebd6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,10 +3,9 @@ #![cfg_attr(not(test), no_std)] #![cfg_attr(feature = "const_fn", feature(const_mut_refs))] // GDT add_entry() -#![cfg_attr(feature = "const_fn", feature(const_fn_fn_ptr_basics))] // IDT new() -#![cfg_attr(feature = "const_fn", feature(const_fn_trait_bound))] // PageSize marker trait #![cfg_attr(feature = "asm_const", feature(asm_const))] #![cfg_attr(feature = "abi_x86_interrupt", feature(abi_x86_interrupt))] +#![cfg_attr(feature = "step_trait", feature(step_trait))] #![cfg_attr(feature = "doc_cfg", feature(doc_cfg))] #![warn(missing_docs)] #![deny(missing_debug_implementations)] @@ -17,37 +16,6 @@ use core::sync::atomic::{AtomicBool, Ordering}; pub use crate::addr::{align_down, align_up, PhysAddr, VirtAddr}; -/// Makes a function const only when `feature = "const_fn"` is enabled. -/// -/// This is needed for const functions with bounds on their generic parameters, -/// such as those in `Page` and `PhysFrame` and many more. -macro_rules! const_fn { - ( - $(#[$attr:meta])* - $sv:vis fn $($fn:tt)* - ) => { - $(#[$attr])* - #[cfg(feature = "const_fn")] - $sv const fn $($fn)* - - $(#[$attr])* - #[cfg(not(feature = "const_fn"))] - $sv fn $($fn)* - }; - ( - $(#[$attr:meta])* - $sv:vis unsafe fn $($fn:tt)* - ) => { - $(#[$attr])* - #[cfg(feature = "const_fn")] - $sv const unsafe fn $($fn)* - - $(#[$attr])* - #[cfg(not(feature = "const_fn"))] - $sv unsafe fn $($fn)* - }; -} - pub mod addr; pub mod instructions; pub mod registers; diff --git a/src/registers/model_specific.rs b/src/registers/model_specific.rs index 503bd57f0..0a390d4e7 100644 --- a/src/registers/model_specific.rs +++ b/src/registers/model_specific.rs @@ -55,6 +55,14 @@ pub struct LStar; #[derive(Debug)] pub struct SFMask; +/// IA32_U_CET: user mode CET configuration +#[derive(Debug)] +pub struct UCet; + +/// IA32_S_CET: supervisor mode CET configuration +#[derive(Debug)] +pub struct SCet; + impl Efer { /// The underlying model specific register. pub const MSR: Msr = Msr(0xC000_0080); @@ -90,6 +98,16 @@ impl SFMask { pub const MSR: Msr = Msr(0xC000_0084); } +impl UCet { + /// The underlying model specific register. + pub const MSR: Msr = Msr(0x6A0); +} + +impl SCet { + /// The underlying model specific register. + pub const MSR: Msr = Msr(0x6A2); +} + bitflags! { /// Flags of the Extended Feature Enable Register. pub struct EferFlags: u64 { @@ -112,12 +130,37 @@ bitflags! { } } +bitflags! { + /// Flags stored in IA32_U_CET and IA32_S_CET (Table-2-2 in Intel SDM Volume + /// 4). The Intel SDM-equivalent names are described in parentheses. + pub struct CetFlags: u64 { + /// Enable shadow stack (SH_STK_EN) + const SS_ENABLE = 1 << 0; + /// Enable WRSS{D,Q}W instructions (WR_SHTK_EN) + const SS_WRITE_ENABLE = 1 << 1; + /// Enable indirect branch tracking (ENDBR_EN) + const IBT_ENABLE = 1 << 2; + /// Enable legacy treatment for indirect branch tracking (LEG_IW_EN) + const IBT_LEGACY_ENABLE = 1 << 3; + /// Enable no-track opcode prefix for indirect branch tracking (NO_TRACK_EN) + const IBT_NO_TRACK_ENABLE = 1 << 4; + /// Disable suppression of CET on legacy compatibility (SUPPRESS_DIS) + const IBT_LEGACY_SUPPRESS_ENABLE = 1 << 5; + /// Enable suppression of indirect branch tracking (SUPPRESS) + const IBT_SUPPRESS_ENABLE = 1 << 10; + /// Is IBT waiting for a branch to return? (read-only, TRACKER) + const IBT_TRACKED = 1 << 11; + } +} + #[cfg(feature = "instructions")] mod x86_64 { use super::*; use crate::addr::VirtAddr; use crate::registers::rflags::RFlags; use crate::structures::gdt::SegmentSelector; + use crate::structures::paging::Page; + use crate::structures::paging::Size4KiB; use crate::PrivilegeLevel; use bit_field::BitField; use core::convert::TryInto; @@ -469,4 +512,74 @@ mod x86_64 { unsafe { msr.write(value.bits()) }; } } + + impl UCet { + /// Read the raw IA32_U_CET. + #[inline] + fn read_raw() -> u64 { + unsafe { Self::MSR.read() } + } + + /// Write the raw IA32_U_CET. + #[inline] + fn write_raw(value: u64) { + let mut msr = Self::MSR; + unsafe { + msr.write(value); + } + } + + /// Read IA32_U_CET. Returns a tuple of the flags and the address to the legacy code page bitmap. + #[inline] + pub fn read() -> (CetFlags, Page) { + let value = Self::read_raw(); + let cet_flags = CetFlags::from_bits_truncate(value); + let legacy_bitmap = + Page::from_start_address(VirtAddr::new(value & !(Page::::SIZE - 1))) + .unwrap(); + + (cet_flags, legacy_bitmap) + } + + /// Write IA32_U_CET. + #[inline] + pub fn write(flags: CetFlags, legacy_bitmap: Page) { + Self::write_raw(flags.bits() | legacy_bitmap.start_address().as_u64()); + } + } + + impl SCet { + /// Read the raw IA32_S_CET. + #[inline] + fn read_raw() -> u64 { + unsafe { Self::MSR.read() } + } + + /// Write the raw IA32_S_CET. + #[inline] + fn write_raw(value: u64) { + let mut msr = Self::MSR; + unsafe { + msr.write(value); + } + } + + /// Read IA32_S_CET. Returns a tuple of the flags and the address to the legacy code page bitmap. + #[inline] + pub fn read() -> (CetFlags, Page) { + let value = Self::read_raw(); + let cet_flags = CetFlags::from_bits_truncate(value); + let legacy_bitmap = + Page::from_start_address(VirtAddr::new(value & !(Page::::SIZE - 1))) + .unwrap(); + + (cet_flags, legacy_bitmap) + } + + /// Write IA32_S_CET. + #[inline] + pub fn write(flags: CetFlags, legacy_bitmap: Page) { + Self::write_raw(flags.bits() | legacy_bitmap.start_address().as_u64()); + } + } } diff --git a/src/structures/gdt.rs b/src/structures/gdt.rs index fdf1ea546..c78e4a03c 100644 --- a/src/structures/gdt.rs +++ b/src/structures/gdt.rs @@ -93,36 +93,34 @@ impl GlobalDescriptorTable { &self.table[..self.next_free] } - const_fn! { - /// Adds the given segment descriptor to the GDT, returning the segment selector. - /// - /// Panics if the GDT has no free entries left. Without the `const_fn` - /// feature, the panic message will be "index out of bounds". - #[inline] - pub fn add_entry(&mut self, entry: Descriptor) -> SegmentSelector { - let index = match entry { - Descriptor::UserSegment(value) => self.push(value), - Descriptor::SystemSegment(value_low, value_high) => { - let index = self.push(value_low); - self.push(value_high); - index - } - }; + /// Adds the given segment descriptor to the GDT, returning the segment selector. + /// + /// Panics if the GDT has no free entries left. + #[inline] + #[cfg_attr(feature = "const_fn", rustversion::attr(all(), const))] + pub fn add_entry(&mut self, entry: Descriptor) -> SegmentSelector { + let index = match entry { + Descriptor::UserSegment(value) => self.push(value), + Descriptor::SystemSegment(value_low, value_high) => { + let index = self.push(value_low); + self.push(value_high); + index + } + }; - let rpl = match entry { - Descriptor::UserSegment(value) => { - if DescriptorFlags::from_bits_truncate(value).contains(DescriptorFlags::DPL_RING_3) - { - PrivilegeLevel::Ring3 - } else { - PrivilegeLevel::Ring0 - } + let rpl = match entry { + Descriptor::UserSegment(value) => { + if DescriptorFlags::from_bits_truncate(value).contains(DescriptorFlags::DPL_RING_3) + { + PrivilegeLevel::Ring3 + } else { + PrivilegeLevel::Ring0 } - Descriptor::SystemSegment(_, _) => PrivilegeLevel::Ring0, - }; + } + Descriptor::SystemSegment(_, _) => PrivilegeLevel::Ring0, + }; - SegmentSelector::new(index as u16, rpl) - } + SegmentSelector::new(index as u16, rpl) } /// Loads the GDT in the CPU using the `lgdt` instruction. This does **not** alter any of the @@ -156,17 +154,16 @@ impl GlobalDescriptorTable { } } - const_fn! { - #[inline] - fn push(&mut self, value: u64) -> usize { - if self.next_free < self.table.len() { - let index = self.next_free; - self.table[index] = value; - self.next_free += 1; - index - } else { - panic!("GDT full"); - } + #[inline] + #[cfg_attr(feature = "const_fn", rustversion::attr(all(), const))] + fn push(&mut self, value: u64) -> usize { + if self.next_free < self.table.len() { + let index = self.next_free; + self.table[index] = value; + self.next_free += 1; + index + } else { + panic!("GDT full"); } } diff --git a/src/structures/idt.rs b/src/structures/idt.rs index be82359d3..86c42ea44 100644 --- a/src/structures/idt.rs +++ b/src/structures/idt.rs @@ -416,38 +416,37 @@ pub struct InterruptDescriptorTable { } impl InterruptDescriptorTable { - const_fn! { - /// Creates a new IDT filled with non-present entries. - #[inline] - pub fn new() -> InterruptDescriptorTable { - InterruptDescriptorTable { - divide_error: Entry::missing(), - debug: Entry::missing(), - non_maskable_interrupt: Entry::missing(), - breakpoint: Entry::missing(), - overflow: Entry::missing(), - bound_range_exceeded: Entry::missing(), - invalid_opcode: Entry::missing(), - device_not_available: Entry::missing(), - double_fault: Entry::missing(), - coprocessor_segment_overrun: Entry::missing(), - invalid_tss: Entry::missing(), - segment_not_present: Entry::missing(), - stack_segment_fault: Entry::missing(), - general_protection_fault: Entry::missing(), - page_fault: Entry::missing(), - reserved_1: Entry::missing(), - x87_floating_point: Entry::missing(), - alignment_check: Entry::missing(), - machine_check: Entry::missing(), - simd_floating_point: Entry::missing(), - virtualization: Entry::missing(), - reserved_2: [Entry::missing(); 8], - vmm_communication_exception: Entry::missing(), - security_exception: Entry::missing(), - reserved_3: Entry::missing(), - interrupts: [Entry::missing(); 256 - 32], - } + /// Creates a new IDT filled with non-present entries. + #[inline] + #[rustversion::attr(since(1.61), const)] + pub fn new() -> InterruptDescriptorTable { + InterruptDescriptorTable { + divide_error: Entry::missing(), + debug: Entry::missing(), + non_maskable_interrupt: Entry::missing(), + breakpoint: Entry::missing(), + overflow: Entry::missing(), + bound_range_exceeded: Entry::missing(), + invalid_opcode: Entry::missing(), + device_not_available: Entry::missing(), + double_fault: Entry::missing(), + coprocessor_segment_overrun: Entry::missing(), + invalid_tss: Entry::missing(), + segment_not_present: Entry::missing(), + stack_segment_fault: Entry::missing(), + general_protection_fault: Entry::missing(), + page_fault: Entry::missing(), + reserved_1: Entry::missing(), + x87_floating_point: Entry::missing(), + alignment_check: Entry::missing(), + machine_check: Entry::missing(), + simd_floating_point: Entry::missing(), + virtualization: Entry::missing(), + reserved_2: [Entry::missing(); 8], + vmm_communication_exception: Entry::missing(), + security_exception: Entry::missing(), + reserved_3: Entry::missing(), + interrupts: [Entry::missing(); 256 - 32], } } @@ -767,11 +766,15 @@ impl Entry { &mut self.options } + /// Returns the virtual address of this IDT entry's handler function. #[inline] - fn handler_addr(&self) -> u64 { - self.pointer_low as u64 + pub fn handler_addr(&self) -> VirtAddr { + let addr = self.pointer_low as u64 | (self.pointer_middle as u64) << 16 - | (self.pointer_high as u64) << 32 + | (self.pointer_high as u64) << 32; + // addr is a valid VirtAddr, as the pointer members are either all zero, + // or have been set by set_handler_addr (which takes a VirtAddr). + VirtAddr::new_truncate(addr) } } diff --git a/src/structures/paging/frame.rs b/src/structures/paging/frame.rs index d64eb6b4f..64935caee 100644 --- a/src/structures/paging/frame.rs +++ b/src/structures/paging/frame.rs @@ -11,7 +11,8 @@ use core::ops::{Add, AddAssign, Sub, SubAssign}; #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[repr(C)] pub struct PhysFrame { - pub(crate) start_address: PhysAddr, // TODO: remove when start_address() is const + // TODO: Make private when our minimum supported stable Rust version is 1.61 + pub(crate) start_address: PhysAddr, size: PhantomData, } @@ -29,18 +30,17 @@ impl PhysFrame { Ok(unsafe { PhysFrame::from_start_address_unchecked(address) }) } - const_fn! { - /// Returns the frame that starts at the given virtual address. - /// - /// ## Safety - /// - /// The address must be correctly aligned. - #[inline] - pub unsafe fn from_start_address_unchecked(start_address: PhysAddr) -> Self { - PhysFrame { - start_address, - size: PhantomData, - } + /// Returns the frame that starts at the given virtual address. + /// + /// ## Safety + /// + /// The address must be correctly aligned. + #[inline] + #[rustversion::attr(since(1.61), const)] + pub unsafe fn from_start_address_unchecked(start_address: PhysAddr) -> Self { + PhysFrame { + start_address, + size: PhantomData, } } @@ -53,36 +53,32 @@ impl PhysFrame { } } - const_fn! { - /// Returns the start address of the frame. - #[inline] - pub fn start_address(self) -> PhysAddr { - self.start_address - } + /// Returns the start address of the frame. + #[inline] + #[rustversion::attr(since(1.61), const)] + pub fn start_address(self) -> PhysAddr { + self.start_address } - const_fn! { - /// Returns the size the frame (4KB, 2MB or 1GB). - #[inline] - pub fn size(self) -> u64 { - S::SIZE - } + /// Returns the size the frame (4KB, 2MB or 1GB). + #[inline] + #[rustversion::attr(since(1.61), const)] + pub fn size(self) -> u64 { + S::SIZE } - const_fn! { - /// Returns a range of frames, exclusive `end`. - #[inline] - pub fn range(start: PhysFrame, end: PhysFrame) -> PhysFrameRange { - PhysFrameRange { start, end } - } + /// Returns a range of frames, exclusive `end`. + #[inline] + #[rustversion::attr(since(1.61), const)] + pub fn range(start: PhysFrame, end: PhysFrame) -> PhysFrameRange { + PhysFrameRange { start, end } } - const_fn! { - /// Returns a range of frames, inclusive `end`. - #[inline] - pub fn range_inclusive(start: PhysFrame, end: PhysFrame) -> PhysFrameRangeInclusive { - PhysFrameRangeInclusive { start, end } - } + /// Returns a range of frames, inclusive `end`. + #[inline] + #[rustversion::attr(since(1.61), const)] + pub fn range_inclusive(start: PhysFrame, end: PhysFrame) -> PhysFrameRangeInclusive { + PhysFrameRangeInclusive { start, end } } } diff --git a/src/structures/paging/page.rs b/src/structures/paging/page.rs index 4a7a9f650..3e2878bfe 100644 --- a/src/structures/paging/page.rs +++ b/src/structures/paging/page.rs @@ -4,6 +4,8 @@ use crate::structures::paging::page_table::PageTableLevel; use crate::structures::paging::PageTableIndex; use crate::VirtAddr; use core::fmt; +#[cfg(feature = "step_trait")] +use core::iter::Step; use core::marker::PhantomData; use core::ops::{Add, AddAssign, Sub, SubAssign}; @@ -75,18 +77,17 @@ impl Page { Ok(Page::containing_address(address)) } - const_fn! { - /// Returns the page that starts at the given virtual address. - /// - /// ## Safety - /// - /// The address must be correctly aligned. - #[inline] - pub unsafe fn from_start_address_unchecked(start_address: VirtAddr) -> Self { - Page { - start_address, - size: PhantomData, - } + /// Returns the page that starts at the given virtual address. + /// + /// ## Safety + /// + /// The address must be correctly aligned. + #[inline] + #[rustversion::attr(since(1.61), const)] + pub unsafe fn from_start_address_unchecked(start_address: VirtAddr) -> Self { + Page { + start_address, + size: PhantomData, } } @@ -99,70 +100,62 @@ impl Page { } } - const_fn! { - /// Returns the start address of the page. - #[inline] - pub fn start_address(self) -> VirtAddr { - self.start_address - } + /// Returns the start address of the page. + #[inline] + #[rustversion::attr(since(1.61), const)] + pub fn start_address(self) -> VirtAddr { + self.start_address } - const_fn! { - /// Returns the size the page (4KB, 2MB or 1GB). - #[inline] - pub fn size(self) -> u64 { - S::SIZE - } + /// Returns the size the page (4KB, 2MB or 1GB). + #[inline] + #[rustversion::attr(since(1.61), const)] + pub fn size(self) -> u64 { + S::SIZE } - const_fn! { - /// Returns the level 4 page table index of this page. - #[inline] - pub fn p4_index(self) -> PageTableIndex { - self.start_address().p4_index() - } + /// Returns the level 4 page table index of this page. + #[inline] + #[rustversion::attr(since(1.61), const)] + pub fn p4_index(self) -> PageTableIndex { + self.start_address().p4_index() } - const_fn! { - /// Returns the level 3 page table index of this page. - #[inline] - pub fn p3_index(self) -> PageTableIndex { - self.start_address().p3_index() - } + /// Returns the level 3 page table index of this page. + #[inline] + #[rustversion::attr(since(1.61), const)] + pub fn p3_index(self) -> PageTableIndex { + self.start_address().p3_index() } - const_fn! { - /// Returns the table index of this page at the specified level. - #[inline] - pub fn page_table_index(self, level: PageTableLevel) -> PageTableIndex { - self.start_address().page_table_index(level) - } + /// Returns the table index of this page at the specified level. + #[inline] + #[rustversion::attr(since(1.61), const)] + pub fn page_table_index(self, level: PageTableLevel) -> PageTableIndex { + self.start_address().page_table_index(level) } - const_fn! { - /// Returns a range of pages, exclusive `end`. - #[inline] - pub fn range(start: Self, end: Self) -> PageRange { - PageRange { start, end } - } + /// Returns a range of pages, exclusive `end`. + #[inline] + #[rustversion::attr(since(1.61), const)] + pub fn range(start: Self, end: Self) -> PageRange { + PageRange { start, end } } - const_fn! { - /// Returns a range of pages, inclusive `end`. - #[inline] - pub fn range_inclusive(start: Self, end: Self) -> PageRangeInclusive { - PageRangeInclusive { start, end } - } + /// Returns a range of pages, inclusive `end`. + #[inline] + #[rustversion::attr(since(1.61), const)] + pub fn range_inclusive(start: Self, end: Self) -> PageRangeInclusive { + PageRangeInclusive { start, end } } } impl Page { - const_fn! { - /// Returns the level 2 page table index of this page. - #[inline] - pub fn p2_index(self) -> PageTableIndex { - self.start_address().p2_index() - } + /// Returns the level 2 page table index of this page. + #[inline] + #[rustversion::attr(since(1.61), const)] + pub fn p2_index(self) -> PageTableIndex { + self.start_address().p2_index() } } @@ -274,6 +267,32 @@ impl Sub for Page { } } +#[cfg(feature = "step_trait")] +impl Step for Page { + fn steps_between(start: &Self, end: &Self) -> Option { + Step::steps_between(&start.start_address, &end.start_address) + .map(|steps| steps / S::SIZE as usize) + } + + fn forward_checked(start: Self, count: usize) -> Option { + let count = count.checked_mul(S::SIZE as usize)?; + let start_address = Step::forward_checked(start.start_address, count)?; + Some(Self { + start_address, + size: PhantomData, + }) + } + + fn backward_checked(start: Self, count: usize) -> Option { + let count = count.checked_mul(S::SIZE as usize)?; + let start_address = Step::backward_checked(start.start_address, count)?; + Some(Self { + start_address, + size: PhantomData, + }) + } +} + /// A range of pages with exclusive upper bound. #[derive(Clone, Copy, PartialEq, Eq, Hash)] #[repr(C)] @@ -352,7 +371,16 @@ impl Iterator for PageRangeInclusive { fn next(&mut self) -> Option { if self.start <= self.end { let page = self.start; - self.start += 1; + + // If the end of the inclusive range is the maximum page possible for size S, + // incrementing start until it is greater than the end will cause an integer overflow. + // So instead, in that case we decrement end rather than incrementing start. + let max_page_addr = VirtAddr::new(u64::MAX) - (S::SIZE - 1); + if self.start.start_address() < max_page_addr { + self.start += 1; + } else { + self.end -= 1; + } Some(page) } else { None @@ -410,4 +438,23 @@ mod tests { } assert_eq!(range_inclusive.next(), None); } + + #[test] + pub fn test_page_range_inclusive_overflow() { + let page_size = Size4KiB::SIZE; + let number = 1000; + + let start_addr = VirtAddr::new(u64::MAX).align_down(page_size) - number * page_size; + let start: Page = Page::containing_address(start_addr); + let end = start + number; + + let mut range_inclusive = Page::range_inclusive(start, end); + for i in 0..=number { + assert_eq!( + range_inclusive.next(), + Some(Page::containing_address(start_addr + page_size * i)) + ); + } + assert_eq!(range_inclusive.next(), None); + } }