Skip to content

Commit

Permalink
WIP: refactor Box/Rc/Arc to reduce code duplication.
Browse files Browse the repository at this point in the history
  • Loading branch information
zachs18 committed Jul 21, 2024
1 parent 02ffd53 commit 6417895
Show file tree
Hide file tree
Showing 2 changed files with 438 additions and 168 deletions.
201 changes: 91 additions & 110 deletions src/allocation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,67 @@ use core::{
mem::ManuallyDrop,
ops::{Deref, DerefMut},
};
use internal::{CastablePointer, OwnedCastablePointer};

unsafe impl<T: ?Sized> CastablePointer<T> for Box<T> {
#[inline]
fn into_raw(self) -> *mut T {
Box::into_raw(self)
}

#[inline]
unsafe fn from_raw(ptr: *mut T) -> Self {
Box::from_raw(ptr)
}
}

unsafe impl<T: ?Sized> CastablePointer<T> for Rc<T> {
#[inline]
fn into_raw(self) -> *mut T {
Rc::into_raw(self) as *mut T
}

#[inline]
unsafe fn from_raw(ptr: *mut T) -> Self {
Rc::from_raw(ptr)
}
}

#[cfg(target_has_atomic = "ptr")]
unsafe impl<T: ?Sized> CastablePointer<T> for Arc<T> {
#[inline]
fn into_raw(self) -> *mut T {
Arc::into_raw(self) as *mut T
}

#[inline]
unsafe fn from_raw(ptr: *mut T) -> Self {
Arc::from_raw(ptr)
}
}

// Safety: This is a valid implementation because according to the docs of
// `std::alloc::GlobalAlloc::dealloc()`, the `Layout` that was used to alloc the
// block must be the same `Layout` that is used to dealloc the block. Luckily,
// `Layout` only stores two things, the alignment, and the size in bytes. So as
// long as both of those stay the same, the Layout will remain a valid input to
// dealloc. This matches the requirements of `OwnedCastablePointer`.
unsafe impl<T: ?Sized> OwnedCastablePointer for Box<T> {}
// Safety: This is a valid implementation because according to the docs of
// std::rc::Rc::from_raw(), the type U that was in the original Rc<U> acquired
// from Rc::into_raw() must have the same size and alignment of the type T in
// the new Rc<T>. So as long as both the size and alignment stay the same, the
// Arc will remain a valid Arc. This matches the requirements of
// `OwnedCastablePointer`.
unsafe impl<T: ?Sized> OwnedCastablePointer for Rc<T> {}
// Safety: This is a valid implementation because according to the docs of
// std::sync::Arc::from_raw(), the type U that was in the original Arc<U>
// acquired from Arc::into_raw() must have the same size and alignment of the
// type T in the new Arc<T>. So as long as both the size and alignment stay the
// same, the Arc will remain a valid Arc. This matches the requirements of
// `OwnedCastablePointer`.
#[cfg(target_has_atomic = "ptr")]
unsafe impl<T: ?Sized> OwnedCastablePointer for Arc<T> {}

/// As [`try_cast_box`], but unwraps for you.
#[inline]
Expand All @@ -44,15 +105,9 @@ pub fn cast_box<A: NoUninit, B: AnyBitPattern>(input: Box<A>) -> Box<B> {
pub fn try_cast_box<A: NoUninit, B: AnyBitPattern>(
input: Box<A>,
) -> Result<Box<B>, (PodCastError, Box<A>)> {
if align_of::<A>() != align_of::<B>() {
Err((PodCastError::AlignmentMismatch, input))
} else if size_of::<A>() != size_of::<B>() {
Err((PodCastError::SizeMismatch, input))
} else {
// Note(Lokathor): This is much simpler than with the Vec casting!
let ptr: *mut B = Box::into_raw(input) as *mut B;
Ok(unsafe { Box::from_raw(ptr) })
}
// Safety: We uphold that A: NoUninit and B: AnyBitPattern, so this cast is
// safe.
unsafe { internal::try_cast_owned_ptr(input) }
}

/// Allocates a `Box<T>` with all of the contents being zeroed out.
Expand Down Expand Up @@ -169,33 +224,9 @@ pub fn cast_slice_box<A: NoUninit, B: AnyBitPattern>(
pub fn try_cast_slice_box<A: NoUninit, B: AnyBitPattern>(
input: Box<[A]>,
) -> Result<Box<[B]>, (PodCastError, Box<[A]>)> {
if align_of::<A>() != align_of::<B>() {
Err((PodCastError::AlignmentMismatch, input))
} else if size_of::<A>() != size_of::<B>() {
if size_of::<A>() * input.len() % size_of::<B>() != 0 {
// If the size in bytes of the underlying buffer does not match an exact
// multiple of the size of B, we cannot cast between them.
Err((PodCastError::SizeMismatch, input))
} else {
// Because the size is an exact multiple, we can now change the length
// of the slice and recreate the Box
// NOTE: This is a valid operation because according to the docs of
// std::alloc::GlobalAlloc::dealloc(), the Layout that was used to alloc
// the block must be the same Layout that is used to dealloc the block.
// Luckily, Layout only stores two things, the alignment, and the size in
// bytes. So as long as both of those stay the same, the Layout will
// remain a valid input to dealloc.
let length = size_of::<A>() * input.len() / size_of::<B>();
let box_ptr: *mut A = Box::into_raw(input) as *mut A;
let ptr: *mut [B] =
unsafe { core::slice::from_raw_parts_mut(box_ptr as *mut B, length) };
Ok(unsafe { Box::<[B]>::from_raw(ptr) })
}
} else {
let box_ptr: *mut [A] = Box::into_raw(input);
let ptr: *mut [B] = box_ptr as *mut [B];
Ok(unsafe { Box::<[B]>::from_raw(ptr) })
}
// Safety: We uphold that A: NoUninit and B: AnyBitPattern, so this cast is
// safe.
unsafe { internal::try_cast_owned_slice_ptr(input) }
}

/// As [`try_cast_vec`], but unwraps for you.
Expand Down Expand Up @@ -324,19 +355,16 @@ pub fn cast_rc<A: NoUninit + AnyBitPattern, B: NoUninit + AnyBitPattern>(
/// * The start and end content type of the `Rc` must have the exact same
/// alignment.
/// * The start and end size of the `Rc` must have the exact same size.
// FIXME(zachs18): Depending on the exact stabilized requirements of
// Rc::get_mut_unchecked, the bounds on this function could be relaxed to A:
// NoUninit, B: AnyBitPattern.
#[inline]
pub fn try_cast_rc<A: NoUninit + AnyBitPattern, B: NoUninit + AnyBitPattern>(
input: Rc<A>,
) -> Result<Rc<B>, (PodCastError, Rc<A>)> {
if align_of::<A>() != align_of::<B>() {
Err((PodCastError::AlignmentMismatch, input))
} else if size_of::<A>() != size_of::<B>() {
Err((PodCastError::SizeMismatch, input))
} else {
// Safety: Rc::from_raw requires size and alignment match, which is met.
let ptr: *const B = Rc::into_raw(input) as *const B;
Ok(unsafe { Rc::from_raw(ptr) })
}
// Safety: We uphold that A: NoUninit and B: AnyBitPattern, so this cast is
// safe.
unsafe { internal::try_cast_owned_ptr(input) }
}

/// As [`try_cast_arc`], but unwraps for you.
Expand All @@ -361,6 +389,9 @@ pub fn cast_arc<A: NoUninit + AnyBitPattern, B: NoUninit + AnyBitPattern>(
/// * The start and end content type of the `Arc` must have the exact same
/// alignment.
/// * The start and end size of the `Arc` must have the exact same size.
// FIXME(zachs18): Depending on the exact stabilized requirements of
// Arc::get_mut_unchecked, the bounds on this function could be relaxed to A:
// NoUninit, B: AnyBitPattern.
#[inline]
#[cfg(target_has_atomic = "ptr")]
pub fn try_cast_arc<
Expand All @@ -369,15 +400,9 @@ pub fn try_cast_arc<
>(
input: Arc<A>,
) -> Result<Arc<B>, (PodCastError, Arc<A>)> {
if align_of::<A>() != align_of::<B>() {
Err((PodCastError::AlignmentMismatch, input))
} else if size_of::<A>() != size_of::<B>() {
Err((PodCastError::SizeMismatch, input))
} else {
// Safety: Arc::from_raw requires size and alignment match, which is met.
let ptr: *const B = Arc::into_raw(input) as *const B;
Ok(unsafe { Arc::from_raw(ptr) })
}
// Safety: We uphold that A: NoUninit and B: AnyBitPattern, so this cast is
// safe.
unsafe { internal::try_cast_owned_ptr(input) }
}

/// As [`try_cast_slice_rc`], but unwraps for you.
Expand Down Expand Up @@ -405,41 +430,19 @@ pub fn cast_slice_rc<
/// alignment.
/// * The start and end content size in bytes of the `Rc<[T]>` must be the exact
/// same.
// FIXME(zachs18): Depending on the exact stabilized requirements of
// Rc::get_mut_unchecked, the bounds on this function could be relaxed to A:
// NoUninit, B: AnyBitPattern.
#[inline]
pub fn try_cast_slice_rc<
A: NoUninit + AnyBitPattern,
B: NoUninit + AnyBitPattern,
>(
input: Rc<[A]>,
) -> Result<Rc<[B]>, (PodCastError, Rc<[A]>)> {
if align_of::<A>() != align_of::<B>() {
Err((PodCastError::AlignmentMismatch, input))
} else if size_of::<A>() != size_of::<B>() {
if size_of::<A>() * input.len() % size_of::<B>() != 0 {
// If the size in bytes of the underlying buffer does not match an exact
// multiple of the size of B, we cannot cast between them.
Err((PodCastError::SizeMismatch, input))
} else {
// Because the size is an exact multiple, we can now change the length
// of the slice and recreate the Rc
// NOTE: This is a valid operation because according to the docs of
// std::rc::Rc::from_raw(), the type U that was in the original Rc<U>
// acquired from Rc::into_raw() must have the same size alignment and
// size of the type T in the new Rc<T>. So as long as both the size
// and alignment stay the same, the Rc will remain a valid Rc.
let length = size_of::<A>() * input.len() / size_of::<B>();
let rc_ptr: *const A = Rc::into_raw(input) as *const A;
// Must use ptr::slice_from_raw_parts, because we cannot make an
// intermediate const reference, because it has mutable provenance,
// nor an intermediate mutable reference, because it could be aliased.
let ptr = core::ptr::slice_from_raw_parts(rc_ptr as *const B, length);
Ok(unsafe { Rc::<[B]>::from_raw(ptr) })
}
} else {
let rc_ptr: *const [A] = Rc::into_raw(input);
let ptr: *const [B] = rc_ptr as *const [B];
Ok(unsafe { Rc::<[B]>::from_raw(ptr) })
}
// Safety: We uphold that A: NoUninit and B: AnyBitPattern, so this cast is
// safe.
unsafe { internal::try_cast_owned_slice_ptr(input) }
}

/// As [`try_cast_slice_arc`], but unwraps for you.
Expand Down Expand Up @@ -468,6 +471,9 @@ pub fn cast_slice_arc<
/// alignment.
/// * The start and end content size in bytes of the `Arc<[T]>` must be the
/// exact same.
// FIXME(zachs18): Depending on the exact stabilized requirements of
// Arc::get_mut_unchecked, the bounds on this function could be relaxed to A:
// NoUninit, B: AnyBitPattern.
#[inline]
#[cfg(target_has_atomic = "ptr")]
pub fn try_cast_slice_arc<
Expand All @@ -476,34 +482,9 @@ pub fn try_cast_slice_arc<
>(
input: Arc<[A]>,
) -> Result<Arc<[B]>, (PodCastError, Arc<[A]>)> {
if align_of::<A>() != align_of::<B>() {
Err((PodCastError::AlignmentMismatch, input))
} else if size_of::<A>() != size_of::<B>() {
if size_of::<A>() * input.len() % size_of::<B>() != 0 {
// If the size in bytes of the underlying buffer does not match an exact
// multiple of the size of B, we cannot cast between them.
Err((PodCastError::SizeMismatch, input))
} else {
// Because the size is an exact multiple, we can now change the length
// of the slice and recreate the Arc
// NOTE: This is a valid operation because according to the docs of
// std::sync::Arc::from_raw(), the type U that was in the original Arc<U>
// acquired from Arc::into_raw() must have the same size alignment and
// size of the type T in the new Arc<T>. So as long as both the size
// and alignment stay the same, the Arc will remain a valid Arc.
let length = size_of::<A>() * input.len() / size_of::<B>();
let arc_ptr: *const A = Arc::into_raw(input) as *const A;
// Must use ptr::slice_from_raw_parts, because we cannot make an
// intermediate const reference, because it has mutable provenance,
// nor an intermediate mutable reference, because it could be aliased.
let ptr = core::ptr::slice_from_raw_parts(arc_ptr as *const B, length);
Ok(unsafe { Arc::<[B]>::from_raw(ptr) })
}
} else {
let arc_ptr: *const [A] = Arc::into_raw(input);
let ptr: *const [B] = arc_ptr as *const [B];
Ok(unsafe { Arc::<[B]>::from_raw(ptr) })
}
// Safety: We uphold that A: NoUninit and B: AnyBitPattern, so this cast is
// safe.
unsafe { internal::try_cast_owned_slice_ptr(input) }
}

/// An extension trait for `TransparentWrapper` and alloc types.
Expand Down
Loading

0 comments on commit 6417895

Please sign in to comment.