Skip to content

Commit

Permalink
Allow aliasing in ArrayView::from_shape
Browse files Browse the repository at this point in the history
Changes the checks in the ArrayView::from_shape constructor so that it
allows a few more cases: custom strides that lead to overlapping are
allowed.

Before, both ArrayViewMut and ArrayView applied the same check, that the
dimensions and strides must be such that no elements can be reached by
more than one index.

However, this rule only applies for mutable data, for ArrayView we can
allow this kind of aliasing. This is in fact how broadcasting works,
where we use strides to repeat the same array data multiple times.
  • Loading branch information
bluss committed Aug 2, 2024
1 parent e578d58 commit 516a504
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 53 deletions.
123 changes: 76 additions & 47 deletions src/dimension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,21 @@ pub fn size_of_shape_checked<D: Dimension>(dim: &D) -> Result<usize, ShapeError>
}
}

/// Select how aliasing is checked
///
/// For owned or mutable data:
///
/// The strides must not allow any element to be referenced by two different indices.
///
#[derive(Copy, Clone, PartialEq)]
pub(crate) enum CanIndexCheckMode
{
/// Owned or mutable: No aliasing
OwnedMutable,
/// Aliasing
ReadOnly,
}

/// Checks whether the given data and dimension meet the invariants of the
/// `ArrayBase` type, assuming the strides are created using
/// `dim.default_strides()` or `dim.fortran_strides()`.
Expand All @@ -125,12 +140,13 @@ pub fn size_of_shape_checked<D: Dimension>(dim: &D) -> Result<usize, ShapeError>
/// `A` and in units of bytes between the least address and greatest address
/// accessible by moving along all axes does not exceed `isize::MAX`.
pub(crate) fn can_index_slice_with_strides<A, D: Dimension>(
data: &[A], dim: &D, strides: &Strides<D>,
data: &[A], dim: &D, strides: &Strides<D>, mode: CanIndexCheckMode,
) -> Result<(), ShapeError>
{
if let Strides::Custom(strides) = strides {
can_index_slice(data, dim, strides)
can_index_slice(data, dim, strides, mode)
} else {
// contiguous shapes: never aliasing, mode does not matter
can_index_slice_not_custom(data.len(), dim)
}
}
Expand Down Expand Up @@ -239,15 +255,19 @@ where D: Dimension
/// allocation. (In other words, the pointer to the first element of the array
/// must be computed using `offset_from_low_addr_ptr_to_logical_ptr` so that
/// negative strides are correctly handled.)
pub(crate) fn can_index_slice<A, D: Dimension>(data: &[A], dim: &D, strides: &D) -> Result<(), ShapeError>
///
/// Note, condition (4) is guaranteed to be checked last
pub(crate) fn can_index_slice<A, D: Dimension>(
data: &[A], dim: &D, strides: &D, mode: CanIndexCheckMode,
) -> Result<(), ShapeError>
{
// Check conditions 1 and 2 and calculate `max_offset`.
let max_offset = max_abs_offset_check_overflow::<A, _>(dim, strides)?;
can_index_slice_impl(max_offset, data.len(), dim, strides)
can_index_slice_impl(max_offset, data.len(), dim, strides, mode)
}

fn can_index_slice_impl<D: Dimension>(
max_offset: usize, data_len: usize, dim: &D, strides: &D,
max_offset: usize, data_len: usize, dim: &D, strides: &D, mode: CanIndexCheckMode,
) -> Result<(), ShapeError>
{
// Check condition 3.
Expand All @@ -260,7 +280,7 @@ fn can_index_slice_impl<D: Dimension>(
}

// Check condition 4.
if !is_empty && dim_stride_overlap(dim, strides) {
if !is_empty && mode != CanIndexCheckMode::ReadOnly && dim_stride_overlap(dim, strides) {
return Err(from_kind(ErrorKind::Unsupported));
}

Expand Down Expand Up @@ -782,6 +802,7 @@ mod test
slice_min_max,
slices_intersect,
solve_linear_diophantine_eq,
CanIndexCheckMode,
IntoDimension,
};
use crate::error::{from_kind, ErrorKind};
Expand All @@ -796,11 +817,11 @@ mod test
let v: alloc::vec::Vec<_> = (0..12).collect();
let dim = (2, 3, 2).into_dimension();
let strides = (1, 2, 6).into_dimension();
assert!(super::can_index_slice(&v, &dim, &strides).is_ok());
assert!(super::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable).is_ok());

let strides = (2, 4, 12).into_dimension();
assert_eq!(
super::can_index_slice(&v, &dim, &strides),
super::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable),
Err(from_kind(ErrorKind::OutOfBounds))
);
}
Expand Down Expand Up @@ -848,71 +869,79 @@ mod test
#[test]
fn can_index_slice_ix0()
{
can_index_slice::<i32, _>(&[1], &Ix0(), &Ix0()).unwrap();
can_index_slice::<i32, _>(&[], &Ix0(), &Ix0()).unwrap_err();
can_index_slice::<i32, _>(&[1], &Ix0(), &Ix0(), CanIndexCheckMode::OwnedMutable).unwrap();
can_index_slice::<i32, _>(&[], &Ix0(), &Ix0(), CanIndexCheckMode::OwnedMutable).unwrap_err();
}

#[test]
fn can_index_slice_ix1()
{
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(0)).unwrap();
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(1)).unwrap();
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(0)).unwrap_err();
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(1)).unwrap_err();
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(0)).unwrap();
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(2)).unwrap();
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(-1isize as usize)).unwrap();
can_index_slice::<i32, _>(&[1], &Ix1(2), &Ix1(1)).unwrap_err();
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(0)).unwrap_err();
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(1)).unwrap();
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(-1isize as usize)).unwrap();
let mode = CanIndexCheckMode::OwnedMutable;
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(0), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix1(0), &Ix1(1), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(0), mode).unwrap_err();
can_index_slice::<i32, _>(&[], &Ix1(1), &Ix1(1), mode).unwrap_err();
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(0), mode).unwrap();
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(2), mode).unwrap();
can_index_slice::<i32, _>(&[1], &Ix1(1), &Ix1(-1isize as usize), mode).unwrap();
can_index_slice::<i32, _>(&[1], &Ix1(2), &Ix1(1), mode).unwrap_err();
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(0), mode).unwrap_err();
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(1), mode).unwrap();
can_index_slice::<i32, _>(&[1, 2], &Ix1(2), &Ix1(-1isize as usize), mode).unwrap();
}

#[test]
fn can_index_slice_ix2()
{
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(0, 0)).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(2, 1)).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(0, 0)).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(2, 1)).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(0, 0)).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err();
can_index_slice::<i32, _>(&[1], &Ix2(1, 2), &Ix2(5, 1)).unwrap_err();
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 1)).unwrap();
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 2)).unwrap_err();
can_index_slice::<i32, _>(&[1, 2, 3, 4, 5], &Ix2(2, 2), &Ix2(3, 1)).unwrap();
can_index_slice::<i32, _>(&[1, 2, 3, 4], &Ix2(2, 2), &Ix2(3, 1)).unwrap_err();
let mode = CanIndexCheckMode::OwnedMutable;
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(0, 0), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 0), &Ix2(2, 1), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(0, 0), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 1), &Ix2(2, 1), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(0, 0), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix2(0, 2), &Ix2(2, 1), mode).unwrap_err();
can_index_slice::<i32, _>(&[1], &Ix2(1, 2), &Ix2(5, 1), mode).unwrap_err();
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 1), mode).unwrap();
can_index_slice::<i32, _>(&[1, 2], &Ix2(1, 2), &Ix2(5, 2), mode).unwrap_err();
can_index_slice::<i32, _>(&[1, 2, 3, 4, 5], &Ix2(2, 2), &Ix2(3, 1), mode).unwrap();
can_index_slice::<i32, _>(&[1, 2, 3, 4], &Ix2(2, 2), &Ix2(3, 1), mode).unwrap_err();

// aliasing strides: ok when readonly
can_index_slice::<i32, _>(&[0; 4], &Ix2(2, 2), &Ix2(1, 1), CanIndexCheckMode::OwnedMutable).unwrap_err();
can_index_slice::<i32, _>(&[0; 4], &Ix2(2, 2), &Ix2(1, 1), CanIndexCheckMode::ReadOnly).unwrap();
}

#[test]
fn can_index_slice_ix3()
{
can_index_slice::<i32, _>(&[], &Ix3(0, 0, 1), &Ix3(2, 1, 3)).unwrap();
can_index_slice::<i32, _>(&[], &Ix3(1, 1, 1), &Ix3(2, 1, 3)).unwrap_err();
can_index_slice::<i32, _>(&[1], &Ix3(1, 1, 1), &Ix3(2, 1, 3)).unwrap();
can_index_slice::<i32, _>(&[1; 11], &Ix3(2, 2, 3), &Ix3(6, 3, 1)).unwrap_err();
can_index_slice::<i32, _>(&[1; 12], &Ix3(2, 2, 3), &Ix3(6, 3, 1)).unwrap();
let mode = CanIndexCheckMode::OwnedMutable;
can_index_slice::<i32, _>(&[], &Ix3(0, 0, 1), &Ix3(2, 1, 3), mode).unwrap();
can_index_slice::<i32, _>(&[], &Ix3(1, 1, 1), &Ix3(2, 1, 3), mode).unwrap_err();
can_index_slice::<i32, _>(&[1], &Ix3(1, 1, 1), &Ix3(2, 1, 3), mode).unwrap();
can_index_slice::<i32, _>(&[1; 11], &Ix3(2, 2, 3), &Ix3(6, 3, 1), mode).unwrap_err();
can_index_slice::<i32, _>(&[1; 12], &Ix3(2, 2, 3), &Ix3(6, 3, 1), mode).unwrap();
}

#[test]
fn can_index_slice_zero_size_elem()
{
can_index_slice::<(), _>(&[], &Ix1(0), &Ix1(1)).unwrap();
can_index_slice::<(), _>(&[()], &Ix1(1), &Ix1(1)).unwrap();
can_index_slice::<(), _>(&[(), ()], &Ix1(2), &Ix1(1)).unwrap();
let mode = CanIndexCheckMode::OwnedMutable;
can_index_slice::<(), _>(&[], &Ix1(0), &Ix1(1), mode).unwrap();
can_index_slice::<(), _>(&[()], &Ix1(1), &Ix1(1), mode).unwrap();
can_index_slice::<(), _>(&[(), ()], &Ix1(2), &Ix1(1), mode).unwrap();

// These might seem okay because the element type is zero-sized, but
// there could be a zero-sized type such that the number of instances
// in existence are carefully controlled.
can_index_slice::<(), _>(&[], &Ix1(1), &Ix1(1)).unwrap_err();
can_index_slice::<(), _>(&[()], &Ix1(2), &Ix1(1)).unwrap_err();
can_index_slice::<(), _>(&[], &Ix1(1), &Ix1(1), mode).unwrap_err();
can_index_slice::<(), _>(&[()], &Ix1(2), &Ix1(1), mode).unwrap_err();

can_index_slice::<(), _>(&[(), ()], &Ix2(2, 1), &Ix2(1, 0)).unwrap();
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(0, 0)).unwrap();
can_index_slice::<(), _>(&[(), ()], &Ix2(2, 1), &Ix2(1, 0), mode).unwrap();
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(0, 0), mode).unwrap();

// This case would be probably be sound, but that's not entirely clear
// and it's not worth the special case code.
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1)).unwrap_err();
can_index_slice::<(), _>(&[], &Ix2(0, 2), &Ix2(2, 1), mode).unwrap_err();
}

quickcheck! {
Expand All @@ -923,8 +952,8 @@ mod test
// Avoid overflow `dim.default_strides()` or `dim.fortran_strides()`.
result.is_err()
} else {
result == can_index_slice(&data, &dim, &dim.default_strides()) &&
result == can_index_slice(&data, &dim, &dim.fortran_strides())
result == can_index_slice(&data, &dim, &dim.default_strides(), CanIndexCheckMode::OwnedMutable) &&
result == can_index_slice(&data, &dim, &dim.fortran_strides(), CanIndexCheckMode::OwnedMutable)
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/impl_constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use num_traits::{One, Zero};
use std::mem;
use std::mem::MaybeUninit;

use crate::dimension;
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
use crate::dimension::{self, CanIndexCheckMode};
use crate::error::{self, ShapeError};
use crate::extension::nonnull::nonnull_from_vec_data;
use crate::imp_prelude::*;
Expand Down Expand Up @@ -466,7 +466,7 @@ where
{
let dim = shape.dim;
let is_custom = shape.strides.is_custom();
dimension::can_index_slice_with_strides(&v, &dim, &shape.strides)?;
dimension::can_index_slice_with_strides(&v, &dim, &shape.strides, dimension::CanIndexCheckMode::OwnedMutable)?;
if !is_custom && dim.size() != v.len() {
return Err(error::incompatible_shapes(&Ix1(v.len()), &dim));
}
Expand Down Expand Up @@ -510,7 +510,7 @@ where
unsafe fn from_vec_dim_stride_unchecked(dim: D, strides: D, mut v: Vec<A>) -> Self
{
// debug check for issues that indicates wrong use of this constructor
debug_assert!(dimension::can_index_slice(&v, &dim, &strides).is_ok());
debug_assert!(dimension::can_index_slice(&v, &dim, &strides, CanIndexCheckMode::OwnedMutable).is_ok());

let ptr = nonnull_from_vec_data(&mut v).add(offset_from_low_addr_ptr_to_logical_ptr(&dim, &strides));
ArrayBase::from_data_ptr(DataOwned::new(v), ptr).with_strides_dim(strides, dim)
Expand Down
6 changes: 3 additions & 3 deletions src/impl_views/constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

use std::ptr::NonNull;

use crate::dimension;
use crate::dimension::offset_from_low_addr_ptr_to_logical_ptr;
use crate::dimension::{self, CanIndexCheckMode};
use crate::error::ShapeError;
use crate::extension::nonnull::nonnull_debug_checked_from_ptr;
use crate::imp_prelude::*;
Expand Down Expand Up @@ -54,7 +54,7 @@ where D: Dimension
fn from_shape_impl(shape: StrideShape<D>, xs: &'a [A]) -> Result<Self, ShapeError>
{
let dim = shape.dim;
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?;
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides, CanIndexCheckMode::ReadOnly)?;
let strides = shape.strides.strides_for_dim(&dim);
unsafe {
Ok(Self::new_(
Expand Down Expand Up @@ -157,7 +157,7 @@ where D: Dimension
fn from_shape_impl(shape: StrideShape<D>, xs: &'a mut [A]) -> Result<Self, ShapeError>
{
let dim = shape.dim;
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides)?;
dimension::can_index_slice_with_strides(xs, &dim, &shape.strides, CanIndexCheckMode::OwnedMutable)?;
let strides = shape.strides.strides_for_dim(&dim);
unsafe {
Ok(Self::new_(
Expand Down
17 changes: 17 additions & 0 deletions tests/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use defmac::defmac;
use itertools::{zip, Itertools};
use ndarray::indices;
use ndarray::prelude::*;
use ndarray::ErrorKind;
use ndarray::{arr3, rcarr2};
use ndarray::{Slice, SliceInfo, SliceInfoElem};
use num_complex::Complex;
Expand Down Expand Up @@ -2060,6 +2061,22 @@ fn test_view_from_shape()
assert_eq!(a, answer);
}

#[test]
fn test_view_from_shape_allow_overlap()
{
let data = [0, 1, 2];
let view = ArrayView::from_shape((2, 3).strides((0, 1)), &data).unwrap();
assert_eq!(view, aview2(&[data; 2]));
}

#[test]
fn test_view_mut_from_shape_deny_overlap()
{
let mut data = [0, 1, 2];
let result = ArrayViewMut::from_shape((2, 3).strides((0, 1)), &mut data);
assert_matches!(result.map_err(|e| e.kind()), Err(ErrorKind::Unsupported));
}

#[test]
fn test_contiguous()
{
Expand Down

0 comments on commit 516a504

Please sign in to comment.