Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cast_{arc,rc} (and slice and try), and {wrap,peel}_{arc,rc}. #132

Merged
merged 4 commits into from
Sep 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
289 changes: 289 additions & 0 deletions src/allocation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use super::*;
use alloc::{
alloc::{alloc_zeroed, Layout},
boxed::Box,
rc::Rc,
sync::Arc,
vec,
vec::Vec,
};
Expand Down Expand Up @@ -315,6 +317,205 @@ pub fn pod_collect_to_vec<
dst
}

/// As [`try_cast_rc`](try_cast_rc), but unwraps for you.
#[inline]
pub fn cast_rc<A: NoUninit + AnyBitPattern, B: NoUninit + AnyBitPattern>(
input: Rc<A>,
) -> Rc<B> {
try_cast_rc(input).map_err(|(e, _v)| e).unwrap()
}

/// Attempts to cast the content type of a [`Rc`](alloc::rc::Rc).
///
/// On failure you get back an error along with the starting `Rc`.
///
/// The bounds on this function are the same as [`cast_mut`], because a user
/// could call `Rc::get_unchecked_mut` on the output, which could be observable
/// in the input.
///
/// ## Failure
///
/// * The start and end content type of the `Rc` must have the exact same
/// alignment.
/// * The start and end size of the `Rc` must have the exact same size.
#[inline]
pub fn try_cast_rc<A: NoUninit + AnyBitPattern, B: NoUninit + AnyBitPattern>(
input: Rc<A>,
) -> Result<Rc<B>, (PodCastError, Rc<A>)> {
if align_of::<A>() != align_of::<B>() {
Err((PodCastError::AlignmentMismatch, input))
} else if size_of::<A>() != size_of::<B>() {
Err((PodCastError::SizeMismatch, input))
} else {
// Safety: Rc::from_raw requires size and alignment match, which is met.
let ptr: *const B = Rc::into_raw(input) as *const B;
Ok(unsafe { Rc::from_raw(ptr) })
}
}

/// As [`try_cast_arc`](try_cast_arc), but unwraps for you.
#[inline]
pub fn cast_arc<A: NoUninit + AnyBitPattern, B: NoUninit + AnyBitPattern>(
input: Arc<A>,
) -> Arc<B> {
try_cast_arc(input).map_err(|(e, _v)| e).unwrap()
}

/// Attempts to cast the content type of a [`Arc`](alloc::sync::Arc).
///
/// On failure you get back an error along with the starting `Arc`.
///
/// The bounds on this function are the same as [`cast_mut`], because a user
/// could call `Rc::get_unchecked_mut` on the output, which could be observable
/// in the input.
///
/// ## Failure
///
/// * The start and end content type of the `Arc` must have the exact same
/// alignment.
/// * The start and end size of the `Arc` must have the exact same size.
#[inline]
pub fn try_cast_arc<
A: NoUninit + AnyBitPattern,
B: NoUninit + AnyBitPattern,
>(
input: Arc<A>,
) -> Result<Arc<B>, (PodCastError, Arc<A>)> {
if align_of::<A>() != align_of::<B>() {
Err((PodCastError::AlignmentMismatch, input))
} else if size_of::<A>() != size_of::<B>() {
Err((PodCastError::SizeMismatch, input))
} else {
// Safety: Arc::from_raw requires size and alignment match, which is met.
let ptr: *const B = Arc::into_raw(input) as *const B;
Ok(unsafe { Arc::from_raw(ptr) })
}
}

/// As [`try_cast_slice_rc`](try_cast_slice_rc), but unwraps for you.
#[inline]
pub fn cast_slice_rc<
A: NoUninit + AnyBitPattern,
B: NoUninit + AnyBitPattern,
>(
input: Rc<[A]>,
) -> Rc<[B]> {
try_cast_slice_rc(input).map_err(|(e, _v)| e).unwrap()
}

/// Attempts to cast the content type of a `Rc<[T]>`.
///
/// On failure you get back an error along with the starting `Rc<[T]>`.
///
/// The bounds on this function are the same as [`cast_mut`], because a user
/// could call `Rc::get_unchecked_mut` on the output, which could be observable
/// in the input.
///
/// ## Failure
///
/// * The start and end content type of the `Rc<[T]>` must have the exact same
/// alignment.
/// * The start and end content size in bytes of the `Rc<[T]>` must be the exact
/// same.
#[inline]
pub fn try_cast_slice_rc<
A: NoUninit + AnyBitPattern,
B: NoUninit + AnyBitPattern,
>(
input: Rc<[A]>,
) -> Result<Rc<[B]>, (PodCastError, Rc<[A]>)> {
if align_of::<A>() != align_of::<B>() {
Err((PodCastError::AlignmentMismatch, input))
} else if size_of::<A>() != size_of::<B>() {
if size_of::<A>() * input.len() % size_of::<B>() != 0 {
// If the size in bytes of the underlying buffer does not match an exact
// multiple of the size of B, we cannot cast between them.
Err((PodCastError::SizeMismatch, input))
} else {
// Because the size is an exact multiple, we can now change the length
// of the slice and recreate the Rc
// NOTE: This is a valid operation because according to the docs of
// std::rc::Rc::from_raw(), the type U that was in the original Rc<U>
// acquired from Rc::into_raw() must have the same size alignment and
// size of the type T in the new Rc<T>. So as long as both the size
// and alignment stay the same, the Rc will remain a valid Rc.
let length = size_of::<A>() * input.len() / size_of::<B>();
let rc_ptr: *const A = Rc::into_raw(input) as *const A;
// Must use ptr::slice_from_raw_parts, because we cannot make an
// intermediate const reference, because it has mutable provenance,
// nor an intermediate mutable reference, because it could be aliased.
let ptr = core::ptr::slice_from_raw_parts(rc_ptr as *const B, length);
Ok(unsafe { Rc::<[B]>::from_raw(ptr) })
}
} else {
let rc_ptr: *const [A] = Rc::into_raw(input);
let ptr: *const [B] = rc_ptr as *const [B];
Ok(unsafe { Rc::<[B]>::from_raw(ptr) })
}
}

/// As [`try_cast_slice_arc`](try_cast_slice_arc), but unwraps for you.
#[inline]
pub fn cast_slice_arc<
A: NoUninit + AnyBitPattern,
B: NoUninit + AnyBitPattern,
>(
input: Arc<[A]>,
) -> Arc<[B]> {
try_cast_slice_arc(input).map_err(|(e, _v)| e).unwrap()
}

/// Attempts to cast the content type of a `Arc<[T]>`.
///
/// On failure you get back an error along with the starting `Arc<[T]>`.
///
/// The bounds on this function are the same as [`cast_mut`], because a user
/// could call `Rc::get_unchecked_mut` on the output, which could be observable
/// in the input.
///
/// ## Failure
///
/// * The start and end content type of the `Arc<[T]>` must have the exact same
/// alignment.
/// * The start and end content size in bytes of the `Arc<[T]>` must be the
/// exact same.
#[inline]
pub fn try_cast_slice_arc<
A: NoUninit + AnyBitPattern,
B: NoUninit + AnyBitPattern,
>(
input: Arc<[A]>,
) -> Result<Arc<[B]>, (PodCastError, Arc<[A]>)> {
if align_of::<A>() != align_of::<B>() {
Err((PodCastError::AlignmentMismatch, input))
} else if size_of::<A>() != size_of::<B>() {
if size_of::<A>() * input.len() % size_of::<B>() != 0 {
// If the size in bytes of the underlying buffer does not match an exact
// multiple of the size of B, we cannot cast between them.
Err((PodCastError::SizeMismatch, input))
} else {
// Because the size is an exact multiple, we can now change the length
// of the slice and recreate the Arc
// NOTE: This is a valid operation because according to the docs of
// std::sync::Arc::from_raw(), the type U that was in the original Arc<U>
// acquired from Arc::into_raw() must have the same size alignment and
// size of the type T in the new Arc<T>. So as long as both the size
// and alignment stay the same, the Arc will remain a valid Arc.
let length = size_of::<A>() * input.len() / size_of::<B>();
let arc_ptr: *const A = Arc::into_raw(input) as *const A;
// Must use ptr::slice_from_raw_parts, because we cannot make an
// intermediate const reference, because it has mutable provenance,
// nor an intermediate mutable reference, because it could be aliased.
let ptr = core::ptr::slice_from_raw_parts(arc_ptr as *const B, length);
Ok(unsafe { Arc::<[B]>::from_raw(ptr) })
}
} else {
let arc_ptr: *const [A] = Arc::into_raw(input);
let ptr: *const [B] = arc_ptr as *const [B];
Ok(unsafe { Arc::<[B]>::from_raw(ptr) })
}
}

/// An extension trait for `TransparentWrapper` and alloc types.
pub trait TransparentWrapperAlloc<Inner: ?Sized>:
TransparentWrapper<Inner>
Expand Down Expand Up @@ -364,6 +565,50 @@ pub trait TransparentWrapperAlloc<Inner: ?Sized>:
}
}

/// Convert an [`Rc`](alloc::rc::Rc) to the inner type into an `Rc` to the
/// wrapper type.
#[inline]
fn wrap_rc(s: Rc<Inner>) -> Rc<Self> {
assert!(size_of::<*mut Inner>() == size_of::<*mut Self>());

unsafe {
// A pointer cast doesn't work here because rustc can't tell that
// the vtables match (because of the `?Sized` restriction relaxation).
// A `transmute` doesn't work because the layout of Rc is unspecified.
//
// SAFETY:
// * The unsafe contract requires that pointers to Inner and Self have
// identical representations, and that the size and alignment of Inner
// and Self are the same, which meets the safety requirements of
// Rc::from_raw
let inner_ptr: *const Inner = Rc::into_raw(s);
let wrapper_ptr: *const Self = transmute!(inner_ptr);
Rc::from_raw(wrapper_ptr)
}
}

/// Convert an [`Arc`](alloc::sync::Arc) to the inner type into an `Arc` to
/// the wrapper type.
#[inline]
fn wrap_arc(s: Arc<Inner>) -> Arc<Self> {
assert!(size_of::<*mut Inner>() == size_of::<*mut Self>());

unsafe {
// A pointer cast doesn't work here because rustc can't tell that
// the vtables match (because of the `?Sized` restriction relaxation).
// A `transmute` doesn't work because the layout of Arc is unspecified.
//
// SAFETY:
// * The unsafe contract requires that pointers to Inner and Self have
// identical representations, and that the size and alignment of Inner
// and Self are the same, which meets the safety requirements of
// Arc::from_raw
let inner_ptr: *const Inner = Arc::into_raw(s);
let wrapper_ptr: *const Self = transmute!(inner_ptr);
Arc::from_raw(wrapper_ptr)
}
}

/// Convert a vec of the wrapper type into a vec of the inner type.
fn peel_vec(s: Vec<Self>) -> Vec<Inner>
where
Expand Down Expand Up @@ -408,5 +653,49 @@ pub trait TransparentWrapperAlloc<Inner: ?Sized>:
Box::from_raw(inner_ptr)
}
}

/// Convert an [`Rc`](alloc::rc::Rc) to the wrapper type into an `Rc` to the
/// inner type.
#[inline]
fn peel_rc(s: Rc<Self>) -> Rc<Inner> {
assert!(size_of::<*mut Inner>() == size_of::<*mut Self>());

unsafe {
// A pointer cast doesn't work here because rustc can't tell that
// the vtables match (because of the `?Sized` restriction relaxation).
// A `transmute` doesn't work because the layout of Rc is unspecified.
//
// SAFETY:
// * The unsafe contract requires that pointers to Inner and Self have
// identical representations, and that the size and alignment of Inner
// and Self are the same, which meets the safety requirements of
// Rc::from_raw
let wrapper_ptr: *const Self = Rc::into_raw(s);
let inner_ptr: *const Inner = transmute!(wrapper_ptr);
Rc::from_raw(inner_ptr)
}
}

/// Convert an [`Arc`](alloc::sync::Arc) to the wrapper type into an `Arc` to
/// the inner type.
#[inline]
fn peel_arc(s: Arc<Self>) -> Arc<Inner> {
assert!(size_of::<*mut Inner>() == size_of::<*mut Self>());

unsafe {
// A pointer cast doesn't work here because rustc can't tell that
// the vtables match (because of the `?Sized` restriction relaxation).
// A `transmute` doesn't work because the layout of Arc is unspecified.
//
// SAFETY:
// * The unsafe contract requires that pointers to Inner and Self have
// identical representations, and that the size and alignment of Inner
// and Self are the same, which meets the safety requirements of
// Arc::from_raw
let wrapper_ptr: *const Self = Arc::into_raw(s);
let inner_ptr: *const Inner = transmute!(wrapper_ptr);
Arc::from_raw(inner_ptr)
}
}
}
impl<I: ?Sized, T: TransparentWrapper<I>> TransparentWrapperAlloc<I> for T {}
15 changes: 15 additions & 0 deletions tests/transparent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ fn test_transparent_wrapper() {
#[cfg(feature = "extern_crate_alloc")]
{
use bytemuck::allocation::TransparentWrapperAlloc;
use std::{rc::Rc, sync::Arc};

let a: Vec<Foreign> = vec![Foreign::default(); 2];

Expand All @@ -92,5 +93,19 @@ fn test_transparent_wrapper() {
assert_eq!(&*e, &0);
let f: Box<Foreign> = Wrapper::peel_box(e);
assert_eq!(&*f, &0);

let g: Rc<Foreign> = Rc::new(Foreign::default());

let h: Rc<Wrapper> = Wrapper::wrap_rc(g);
assert_eq!(&*h, &0);
let i: Rc<Foreign> = Wrapper::peel_rc(h);
assert_eq!(&*i, &0);

let j: Arc<Foreign> = Arc::new(Foreign::default());

let k: Arc<Wrapper> = Wrapper::wrap_arc(j);
assert_eq!(&*k, &0);
let l: Arc<Foreign> = Wrapper::peel_arc(k);
assert_eq!(&*l, &0);
}
}