Skip to content

Commit

Permalink
Merge pull request #734 from rust-ndarray/raw-view-cast
Browse files Browse the repository at this point in the history
Simplify ArrayView construction from NonNull<T> and add RawView .cast() method
  • Loading branch information
bluss committed Sep 30, 2019
2 parents ad3340a + ac55d74 commit ecb7643
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 31 deletions.
10 changes: 5 additions & 5 deletions src/impl_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ where
S: Data,
{
debug_assert!(self.pointer_is_inbounds());
unsafe { ArrayView::new_(self.ptr.as_ptr(), self.dim.clone(), self.strides.clone()) }
unsafe { ArrayView::new(self.ptr, self.dim.clone(), self.strides.clone()) }
}

/// Return a read-write view of the array
Expand All @@ -148,7 +148,7 @@ where
S: DataMut,
{
self.ensure_unique();
unsafe { ArrayViewMut::new_(self.ptr.as_ptr(), self.dim.clone(), self.strides.clone()) }
unsafe { ArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) }
}

/// Return an uniquely owned copy of the array.
Expand Down Expand Up @@ -1313,7 +1313,7 @@ where
/// Return a raw view of the array.
#[inline]
pub fn raw_view(&self) -> RawArrayView<A, D> {
unsafe { RawArrayView::new_(self.ptr.as_ptr(), self.dim.clone(), self.strides.clone()) }
unsafe { RawArrayView::new(self.ptr, self.dim.clone(), self.strides.clone()) }
}

/// Return a raw mutable view of the array.
Expand All @@ -1323,7 +1323,7 @@ where
S: RawDataMut,
{
self.try_ensure_unique(); // for RcArray
unsafe { RawArrayViewMut::new_(self.ptr.as_ptr(), self.dim.clone(), self.strides.clone()) }
unsafe { RawArrayViewMut::new(self.ptr, self.dim.clone(), self.strides.clone()) }
}

/// Return the array’s data as a slice, if it is contiguous and in standard order.
Expand Down Expand Up @@ -1620,7 +1620,7 @@ where
Some(st) => st,
None => return None,
};
unsafe { Some(ArrayView::new_(self.ptr.as_ptr(), dim, broadcast_strides)) }
unsafe { Some(ArrayView::new(self.ptr, dim, broadcast_strides)) }
}

/// Swap axes `ax` and `bx`.
Expand Down
87 changes: 75 additions & 12 deletions src/impl_raw_views.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::mem;
use std::ptr::NonNull;

use crate::dimension::{self, stride_offset};
use crate::extension::nonnull::nonnull_debug_checked_from_ptr;
use crate::imp_prelude::*;
Expand All @@ -11,16 +14,20 @@ where
///
/// Unsafe because caller is responsible for ensuring that the array will
/// meet all of the invariants of the `ArrayBase` type.
#[inline(always)]
pub(crate) unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self {
#[inline]
pub(crate) unsafe fn new(ptr: NonNull<A>, dim: D, strides: D) -> Self {
RawArrayView {
data: RawViewRepr::new(),
ptr: nonnull_debug_checked_from_ptr(ptr as *mut _),
ptr,
dim,
strides,
}
}

unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self {
Self::new(nonnull_debug_checked_from_ptr(ptr as *mut A), dim, strides)
}

/// Create an `RawArrayView<A, D>` from shape information and a raw pointer
/// to the elements.
///
Expand Down Expand Up @@ -76,7 +83,7 @@ where
/// ensure that all of the data is valid and choose the correct lifetime.
#[inline]
pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> {
ArrayView::new_(self.ptr.as_ptr(), self.dim, self.strides)
ArrayView::new(self.ptr, self.dim, self.strides)
}

/// Split the array view along `axis` and return one array pointer strictly
Expand Down Expand Up @@ -105,6 +112,32 @@ where

(left, right)
}

/// Cast the raw pointer of the raw array view to a different type
///
/// **Panics** if element size is not compatible.
///
/// Lack of panic does not imply it is a valid cast. The cast works the same
/// way as regular raw pointer casts.
///
/// While this method is safe, for the same reason as regular raw pointer
/// casts are safe, access through the produced raw view is only possible
/// in an unsafe block or function.
pub fn cast<B>(self) -> RawArrayView<B, D> {
assert_eq!(
mem::size_of::<B>(),
mem::size_of::<A>(),
"size mismatch in raw view cast"
);
let ptr = self.ptr.cast::<B>();
debug_assert!(
is_aligned(ptr.as_ptr()),
"alignment mismatch in raw view cast"
);
/* Alignment checked with debug assertion: alignment could be dynamically correct,
* and we don't have a check that compiles out for that. */
unsafe { RawArrayView::new(ptr, self.dim, self.strides) }
}
}

impl<A, D> RawArrayViewMut<A, D>
Expand All @@ -115,16 +148,20 @@ where
///
/// Unsafe because caller is responsible for ensuring that the array will
/// meet all of the invariants of the `ArrayBase` type.
#[inline(always)]
pub(crate) unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self {
#[inline]
pub(crate) unsafe fn new(ptr: NonNull<A>, dim: D, strides: D) -> Self {
RawArrayViewMut {
data: RawViewRepr::new(),
ptr: nonnull_debug_checked_from_ptr(ptr),
ptr,
dim,
strides,
}
}

unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self {
Self::new(nonnull_debug_checked_from_ptr(ptr), dim, strides)
}

/// Create an `RawArrayViewMut<A, D>` from shape information and a raw
/// pointer to the elements.
///
Expand Down Expand Up @@ -176,7 +213,7 @@ where
/// Converts to a non-mutable `RawArrayView`.
#[inline]
pub(crate) fn into_raw_view(self) -> RawArrayView<A, D> {
unsafe { RawArrayView::new_(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { RawArrayView::new(self.ptr, self.dim, self.strides) }
}

/// Converts to a read-only view of the array.
Expand All @@ -186,7 +223,7 @@ where
/// ensure that all of the data is valid and choose the correct lifetime.
#[inline]
pub unsafe fn deref_into_view<'a>(self) -> ArrayView<'a, A, D> {
ArrayView::new_(self.ptr.as_ptr(), self.dim, self.strides)
ArrayView::new(self.ptr, self.dim, self.strides)
}

/// Converts to a mutable view of the array.
Expand All @@ -196,7 +233,7 @@ where
/// ensure that all of the data is valid and choose the correct lifetime.
#[inline]
pub unsafe fn deref_into_view_mut<'a>(self) -> ArrayViewMut<'a, A, D> {
ArrayViewMut::new_(self.ptr.as_ptr(), self.dim, self.strides)
ArrayViewMut::new(self.ptr, self.dim, self.strides)
}

/// Split the array view along `axis` and return one array pointer strictly
Expand All @@ -207,9 +244,35 @@ where
let (left, right) = self.into_raw_view().split_at(axis, index);
unsafe {
(
Self::new_(left.ptr.as_ptr(), left.dim, left.strides),
Self::new_(right.ptr.as_ptr(), right.dim, right.strides),
Self::new(left.ptr, left.dim, left.strides),
Self::new(right.ptr, right.dim, right.strides),
)
}
}

/// Cast the raw pointer of the raw array view to a different type
///
/// **Panics** if element size is not compatible.
///
/// Lack of panic does not imply it is a valid cast. The cast works the same
/// way as regular raw pointer casts.
///
/// While this method is safe, for the same reason as regular raw pointer
/// casts are safe, access through the produced raw view is only possible
/// in an unsafe block or function.
pub fn cast<B>(self) -> RawArrayViewMut<B, D> {
assert_eq!(
mem::size_of::<B>(),
mem::size_of::<A>(),
"size mismatch in raw view cast"
);
let ptr = self.ptr.cast::<B>();
debug_assert!(
is_aligned(ptr.as_ptr()),
"alignment mismatch in raw view cast"
);
/* Alignment checked with debug assertion: alignment could be dynamically correct,
* and we don't have a check that compiles out for that. */
unsafe { RawArrayViewMut::new(ptr, self.dim, self.strides) }
}
}
35 changes: 27 additions & 8 deletions src/impl_views/constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use std::ptr::NonNull;

use crate::dimension;
use crate::error::ShapeError;
use crate::extension::nonnull::nonnull_debug_checked_from_ptr;
Expand Down Expand Up @@ -200,11 +202,11 @@ where

/// Convert the view into an `ArrayViewMut<'b, A, D>` where `'b` is a lifetime
/// outlived by `'a'`.
pub fn reborrow<'b>(mut self) -> ArrayViewMut<'b, A, D>
pub fn reborrow<'b>(self) -> ArrayViewMut<'b, A, D>
where
'a: 'b,
{
unsafe { ArrayViewMut::new_(self.as_mut_ptr(), self.dim, self.strides) }
unsafe { ArrayViewMut::new(self.ptr, self.dim, self.strides) }
}
}

Expand All @@ -217,14 +219,24 @@ where
///
/// Unsafe because: `ptr` must be valid for the given dimension and strides.
#[inline(always)]
pub(crate) unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self {
pub(crate) unsafe fn new(ptr: NonNull<A>, dim: D, strides: D) -> Self {
if cfg!(debug_assertions) {
assert!(is_aligned(ptr.as_ptr()), "The pointer must be aligned.");
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
}
ArrayView {
data: ViewRepr::new(),
ptr: nonnull_debug_checked_from_ptr(ptr as *mut A),
ptr,
dim,
strides,
}
}

/// Unsafe because: `ptr` must be valid for the given dimension and strides.
#[inline]
pub(crate) unsafe fn new_(ptr: *const A, dim: D, strides: D) -> Self {
Self::new(nonnull_debug_checked_from_ptr(ptr as *mut A), dim, strides)
}
}

impl<'a, A, D> ArrayViewMut<'a, A, D>
Expand All @@ -235,17 +247,24 @@ where
///
/// Unsafe because: `ptr` must be valid for the given dimension and strides.
#[inline(always)]
pub(crate) unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self {
pub(crate) unsafe fn new(ptr: NonNull<A>, dim: D, strides: D) -> Self {
if cfg!(debug_assertions) {
assert!(!ptr.is_null(), "The pointer must be non-null.");
assert!(is_aligned(ptr), "The pointer must be aligned.");
assert!(is_aligned(ptr.as_ptr()), "The pointer must be aligned.");
dimension::max_abs_offset_check_overflow::<A, _>(&dim, &strides).unwrap();
}
ArrayViewMut {
data: ViewRepr::new(),
ptr: nonnull_debug_checked_from_ptr(ptr),
ptr,
dim,
strides,
}
}

/// Create a new `ArrayView`
///
/// Unsafe because: `ptr` must be valid for the given dimension and strides.
#[inline(always)]
pub(crate) unsafe fn new_(ptr: *mut A, dim: D, strides: D) -> Self {
Self::new(nonnull_debug_checked_from_ptr(ptr), dim, strides)
}
}
8 changes: 4 additions & 4 deletions src/impl_views/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ where
where
'a: 'b,
{
unsafe { ArrayView::new_(self.as_ptr(), self.dim, self.strides) }
unsafe { ArrayView::new(self.ptr, self.dim, self.strides) }
}

/// Return the array’s data as a slice, if it is contiguous and in standard order.
Expand All @@ -53,7 +53,7 @@ where

/// Converts to a raw array view.
pub(crate) fn into_raw_view(self) -> RawArrayView<A, D> {
unsafe { RawArrayView::new_(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { RawArrayView::new(self.ptr, self.dim, self.strides) }
}
}

Expand Down Expand Up @@ -161,12 +161,12 @@ where
{
// Convert into a read-only view
pub(crate) fn into_view(self) -> ArrayView<'a, A, D> {
unsafe { ArrayView::new_(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { ArrayView::new(self.ptr, self.dim, self.strides) }
}

/// Converts to a mutable raw array view.
pub(crate) fn into_raw_view_mut(self) -> RawArrayViewMut<A, D> {
unsafe { RawArrayViewMut::new_(self.ptr.as_ptr(), self.dim, self.strides) }
unsafe { RawArrayViewMut::new(self.ptr, self.dim, self.strides) }
}

#[inline]
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1517,7 +1517,7 @@ where
let ptr = self.ptr;
let mut strides = dim.clone();
strides.slice_mut().copy_from_slice(self.strides.slice());
unsafe { ArrayView::new_(ptr.as_ptr(), dim, strides) }
unsafe { ArrayView::new(ptr, dim, strides) }
}

fn raw_strides(&self) -> D {
Expand Down
2 changes: 1 addition & 1 deletion src/zip/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ where
type Output = ArrayView<'a, A, E::Dim>;
fn broadcast_unwrap(self, shape: E) -> Self::Output {
let res: ArrayView<'_, A, E::Dim> = (&self).broadcast_unwrap(shape.into_dimension());
unsafe { ArrayView::new_(res.ptr.as_ptr(), res.dim, res.strides) }
unsafe { ArrayView::new(res.ptr, res.dim, res.strides) }
}
private_impl! {}
}
Expand Down
Loading

0 comments on commit ecb7643

Please sign in to comment.