Skip to content

Commit

Permalink
Add cast_{arc,rc} (and slice and try), and {wrap,peel}_{arc,rc}. (#…
Browse files Browse the repository at this point in the history
…132)

* Add `allocation::{try_,}cast_{arc,rc}`, and add `{wrap,peel}_{arc,rc}` to `TransparentWrapperAlloc`.

* Avoid intermediate slice reference in `try_cast_slice_{arc,rc}`.

* remove `unsafe` block; run `cargo +nightly fmt` (ignoring files I didn't modify)

* Make `cast_rc` (etc) have the same bounds as `cast_mut`, due to the existence of `Rc::get_mut_unchecked`.
  • Loading branch information
zachs18 authored Sep 1, 2022
1 parent 950a3ed commit 09dd2ff
Show file tree
Hide file tree
Showing 2 changed files with 304 additions and 0 deletions.
289 changes: 289 additions & 0 deletions src/allocation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ use super::*;
use alloc::{
alloc::{alloc_zeroed, Layout},
boxed::Box,
rc::Rc,
sync::Arc,
vec,
vec::Vec,
};
Expand Down Expand Up @@ -315,6 +317,205 @@ pub fn pod_collect_to_vec<
dst
}

/// As [`try_cast_rc`](try_cast_rc), but unwraps for you.
#[inline]
pub fn cast_rc<A: NoUninit + AnyBitPattern, B: NoUninit + AnyBitPattern>(
input: Rc<A>,
) -> Rc<B> {
try_cast_rc(input).map_err(|(e, _v)| e).unwrap()
}

/// Attempts to cast the content type of a [`Rc`](alloc::rc::Rc).
///
/// On failure you get back an error along with the starting `Rc`.
///
/// The bounds on this function are the same as [`cast_mut`], because a user
/// could call `Rc::get_unchecked_mut` on the output, which could be observable
/// in the input.
///
/// ## Failure
///
/// * The start and end content type of the `Rc` must have the exact same
/// alignment.
/// * The start and end size of the `Rc` must have the exact same size.
#[inline]
pub fn try_cast_rc<A: NoUninit + AnyBitPattern, B: NoUninit + AnyBitPattern>(
input: Rc<A>,
) -> Result<Rc<B>, (PodCastError, Rc<A>)> {
if align_of::<A>() != align_of::<B>() {
Err((PodCastError::AlignmentMismatch, input))
} else if size_of::<A>() != size_of::<B>() {
Err((PodCastError::SizeMismatch, input))
} else {
// Safety: Rc::from_raw requires size and alignment match, which is met.
let ptr: *const B = Rc::into_raw(input) as *const B;
Ok(unsafe { Rc::from_raw(ptr) })
}
}

/// As [`try_cast_arc`](try_cast_arc), but unwraps for you.
#[inline]
pub fn cast_arc<A: NoUninit + AnyBitPattern, B: NoUninit + AnyBitPattern>(
input: Arc<A>,
) -> Arc<B> {
try_cast_arc(input).map_err(|(e, _v)| e).unwrap()
}

/// Attempts to cast the content type of a [`Arc`](alloc::sync::Arc).
///
/// On failure you get back an error along with the starting `Arc`.
///
/// The bounds on this function are the same as [`cast_mut`], because a user
/// could call `Rc::get_unchecked_mut` on the output, which could be observable
/// in the input.
///
/// ## Failure
///
/// * The start and end content type of the `Arc` must have the exact same
/// alignment.
/// * The start and end size of the `Arc` must have the exact same size.
#[inline]
pub fn try_cast_arc<
A: NoUninit + AnyBitPattern,
B: NoUninit + AnyBitPattern,
>(
input: Arc<A>,
) -> Result<Arc<B>, (PodCastError, Arc<A>)> {
if align_of::<A>() != align_of::<B>() {
Err((PodCastError::AlignmentMismatch, input))
} else if size_of::<A>() != size_of::<B>() {
Err((PodCastError::SizeMismatch, input))
} else {
// Safety: Arc::from_raw requires size and alignment match, which is met.
let ptr: *const B = Arc::into_raw(input) as *const B;
Ok(unsafe { Arc::from_raw(ptr) })
}
}

/// As [`try_cast_slice_rc`](try_cast_slice_rc), but unwraps for you.
#[inline]
pub fn cast_slice_rc<
A: NoUninit + AnyBitPattern,
B: NoUninit + AnyBitPattern,
>(
input: Rc<[A]>,
) -> Rc<[B]> {
try_cast_slice_rc(input).map_err(|(e, _v)| e).unwrap()
}

/// Attempts to cast the content type of a `Rc<[T]>`.
///
/// On failure you get back an error along with the starting `Rc<[T]>`.
///
/// The bounds on this function are the same as [`cast_mut`], because a user
/// could call `Rc::get_unchecked_mut` on the output, which could be observable
/// in the input.
///
/// ## Failure
///
/// * The start and end content type of the `Rc<[T]>` must have the exact same
/// alignment.
/// * The start and end content size in bytes of the `Rc<[T]>` must be the exact
/// same.
#[inline]
pub fn try_cast_slice_rc<
A: NoUninit + AnyBitPattern,
B: NoUninit + AnyBitPattern,
>(
input: Rc<[A]>,
) -> Result<Rc<[B]>, (PodCastError, Rc<[A]>)> {
if align_of::<A>() != align_of::<B>() {
Err((PodCastError::AlignmentMismatch, input))
} else if size_of::<A>() != size_of::<B>() {
if size_of::<A>() * input.len() % size_of::<B>() != 0 {
// If the size in bytes of the underlying buffer does not match an exact
// multiple of the size of B, we cannot cast between them.
Err((PodCastError::SizeMismatch, input))
} else {
// Because the size is an exact multiple, we can now change the length
// of the slice and recreate the Rc
// NOTE: This is a valid operation because according to the docs of
// std::rc::Rc::from_raw(), the type U that was in the original Rc<U>
// acquired from Rc::into_raw() must have the same size alignment and
// size of the type T in the new Rc<T>. So as long as both the size
// and alignment stay the same, the Rc will remain a valid Rc.
let length = size_of::<A>() * input.len() / size_of::<B>();
let rc_ptr: *const A = Rc::into_raw(input) as *const A;
// Must use ptr::slice_from_raw_parts, because we cannot make an
// intermediate const reference, because it has mutable provenance,
// nor an intermediate mutable reference, because it could be aliased.
let ptr = core::ptr::slice_from_raw_parts(rc_ptr as *const B, length);
Ok(unsafe { Rc::<[B]>::from_raw(ptr) })
}
} else {
let rc_ptr: *const [A] = Rc::into_raw(input);
let ptr: *const [B] = rc_ptr as *const [B];
Ok(unsafe { Rc::<[B]>::from_raw(ptr) })
}
}

/// As [`try_cast_slice_arc`](try_cast_slice_arc), but unwraps for you.
#[inline]
pub fn cast_slice_arc<
A: NoUninit + AnyBitPattern,
B: NoUninit + AnyBitPattern,
>(
input: Arc<[A]>,
) -> Arc<[B]> {
try_cast_slice_arc(input).map_err(|(e, _v)| e).unwrap()
}

/// Attempts to cast the content type of a `Arc<[T]>`.
///
/// On failure you get back an error along with the starting `Arc<[T]>`.
///
/// The bounds on this function are the same as [`cast_mut`], because a user
/// could call `Rc::get_unchecked_mut` on the output, which could be observable
/// in the input.
///
/// ## Failure
///
/// * The start and end content type of the `Arc<[T]>` must have the exact same
/// alignment.
/// * The start and end content size in bytes of the `Arc<[T]>` must be the
/// exact same.
#[inline]
pub fn try_cast_slice_arc<
A: NoUninit + AnyBitPattern,
B: NoUninit + AnyBitPattern,
>(
input: Arc<[A]>,
) -> Result<Arc<[B]>, (PodCastError, Arc<[A]>)> {
if align_of::<A>() != align_of::<B>() {
Err((PodCastError::AlignmentMismatch, input))
} else if size_of::<A>() != size_of::<B>() {
if size_of::<A>() * input.len() % size_of::<B>() != 0 {
// If the size in bytes of the underlying buffer does not match an exact
// multiple of the size of B, we cannot cast between them.
Err((PodCastError::SizeMismatch, input))
} else {
// Because the size is an exact multiple, we can now change the length
// of the slice and recreate the Arc
// NOTE: This is a valid operation because according to the docs of
// std::sync::Arc::from_raw(), the type U that was in the original Arc<U>
// acquired from Arc::into_raw() must have the same size alignment and
// size of the type T in the new Arc<T>. So as long as both the size
// and alignment stay the same, the Arc will remain a valid Arc.
let length = size_of::<A>() * input.len() / size_of::<B>();
let arc_ptr: *const A = Arc::into_raw(input) as *const A;
// Must use ptr::slice_from_raw_parts, because we cannot make an
// intermediate const reference, because it has mutable provenance,
// nor an intermediate mutable reference, because it could be aliased.
let ptr = core::ptr::slice_from_raw_parts(arc_ptr as *const B, length);
Ok(unsafe { Arc::<[B]>::from_raw(ptr) })
}
} else {
let arc_ptr: *const [A] = Arc::into_raw(input);
let ptr: *const [B] = arc_ptr as *const [B];
Ok(unsafe { Arc::<[B]>::from_raw(ptr) })
}
}

/// An extension trait for `TransparentWrapper` and alloc types.
pub trait TransparentWrapperAlloc<Inner: ?Sized>:
TransparentWrapper<Inner>
Expand Down Expand Up @@ -364,6 +565,50 @@ pub trait TransparentWrapperAlloc<Inner: ?Sized>:
}
}

/// Convert an [`Rc`](alloc::rc::Rc) to the inner type into an `Rc` to the
/// wrapper type.
#[inline]
fn wrap_rc(s: Rc<Inner>) -> Rc<Self> {
assert!(size_of::<*mut Inner>() == size_of::<*mut Self>());

unsafe {
// A pointer cast doesn't work here because rustc can't tell that
// the vtables match (because of the `?Sized` restriction relaxation).
// A `transmute` doesn't work because the layout of Rc is unspecified.
//
// SAFETY:
// * The unsafe contract requires that pointers to Inner and Self have
// identical representations, and that the size and alignment of Inner
// and Self are the same, which meets the safety requirements of
// Rc::from_raw
let inner_ptr: *const Inner = Rc::into_raw(s);
let wrapper_ptr: *const Self = transmute!(inner_ptr);
Rc::from_raw(wrapper_ptr)
}
}

/// Convert an [`Arc`](alloc::sync::Arc) to the inner type into an `Arc` to
/// the wrapper type.
#[inline]
fn wrap_arc(s: Arc<Inner>) -> Arc<Self> {
assert!(size_of::<*mut Inner>() == size_of::<*mut Self>());

unsafe {
// A pointer cast doesn't work here because rustc can't tell that
// the vtables match (because of the `?Sized` restriction relaxation).
// A `transmute` doesn't work because the layout of Arc is unspecified.
//
// SAFETY:
// * The unsafe contract requires that pointers to Inner and Self have
// identical representations, and that the size and alignment of Inner
// and Self are the same, which meets the safety requirements of
// Arc::from_raw
let inner_ptr: *const Inner = Arc::into_raw(s);
let wrapper_ptr: *const Self = transmute!(inner_ptr);
Arc::from_raw(wrapper_ptr)
}
}

/// Convert a vec of the wrapper type into a vec of the inner type.
fn peel_vec(s: Vec<Self>) -> Vec<Inner>
where
Expand Down Expand Up @@ -408,5 +653,49 @@ pub trait TransparentWrapperAlloc<Inner: ?Sized>:
Box::from_raw(inner_ptr)
}
}

/// Convert an [`Rc`](alloc::rc::Rc) to the wrapper type into an `Rc` to the
/// inner type.
#[inline]
fn peel_rc(s: Rc<Self>) -> Rc<Inner> {
assert!(size_of::<*mut Inner>() == size_of::<*mut Self>());

unsafe {
// A pointer cast doesn't work here because rustc can't tell that
// the vtables match (because of the `?Sized` restriction relaxation).
// A `transmute` doesn't work because the layout of Rc is unspecified.
//
// SAFETY:
// * The unsafe contract requires that pointers to Inner and Self have
// identical representations, and that the size and alignment of Inner
// and Self are the same, which meets the safety requirements of
// Rc::from_raw
let wrapper_ptr: *const Self = Rc::into_raw(s);
let inner_ptr: *const Inner = transmute!(wrapper_ptr);
Rc::from_raw(inner_ptr)
}
}

/// Convert an [`Arc`](alloc::sync::Arc) to the wrapper type into an `Arc` to
/// the inner type.
#[inline]
fn peel_arc(s: Arc<Self>) -> Arc<Inner> {
assert!(size_of::<*mut Inner>() == size_of::<*mut Self>());

unsafe {
// A pointer cast doesn't work here because rustc can't tell that
// the vtables match (because of the `?Sized` restriction relaxation).
// A `transmute` doesn't work because the layout of Arc is unspecified.
//
// SAFETY:
// * The unsafe contract requires that pointers to Inner and Self have
// identical representations, and that the size and alignment of Inner
// and Self are the same, which meets the safety requirements of
// Arc::from_raw
let wrapper_ptr: *const Self = Arc::into_raw(s);
let inner_ptr: *const Inner = transmute!(wrapper_ptr);
Arc::from_raw(inner_ptr)
}
}
}
impl<I: ?Sized, T: TransparentWrapper<I>> TransparentWrapperAlloc<I> for T {}
15 changes: 15 additions & 0 deletions tests/transparent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ fn test_transparent_wrapper() {
#[cfg(feature = "extern_crate_alloc")]
{
use bytemuck::allocation::TransparentWrapperAlloc;
use std::{rc::Rc, sync::Arc};

let a: Vec<Foreign> = vec![Foreign::default(); 2];

Expand All @@ -92,5 +93,19 @@ fn test_transparent_wrapper() {
assert_eq!(&*e, &0);
let f: Box<Foreign> = Wrapper::peel_box(e);
assert_eq!(&*f, &0);

let g: Rc<Foreign> = Rc::new(Foreign::default());

let h: Rc<Wrapper> = Wrapper::wrap_rc(g);
assert_eq!(&*h, &0);
let i: Rc<Foreign> = Wrapper::peel_rc(h);
assert_eq!(&*i, &0);

let j: Arc<Foreign> = Arc::new(Foreign::default());

let k: Arc<Wrapper> = Wrapper::wrap_arc(j);
assert_eq!(&*k, &0);
let l: Arc<Foreign> = Wrapper::peel_arc(k);
assert_eq!(&*l, &0);
}
}

0 comments on commit 09dd2ff

Please sign in to comment.