Skip to content

Commit

Permalink
Merge pull request #1310 from rust-ndarray/better-into-shape
Browse files Browse the repository at this point in the history
Better shape: Deprecate reshape, into_shape
  • Loading branch information
bluss authored Mar 10, 2024
2 parents cd0a956 + d32248d commit ba8f45c
Show file tree
Hide file tree
Showing 34 changed files with 456 additions and 199 deletions.
10 changes: 5 additions & 5 deletions README-quick-start.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ fn main() {
println!("a shape {:?}", &a.shape());
println!("b shape {:?}", &b.shape());

let b = b.into_shape((4,1)).unwrap(); // reshape b to shape [4, 1]
let b = b.into_shape_with_order((4,1)).unwrap(); // reshape b to shape [4, 1]
println!("b shape after reshape {:?}", &b.shape());

println!("{}", a.dot(&b)); // [1, 4] x [4, 1] -> [1, 1]
Expand Down Expand Up @@ -295,7 +295,7 @@ row: [[100, 101, 102],
## Shape Manipulation

### Changing the shape of an array
The shape of an array can be changed with `into_shape` method.
The shape of an array can be changed with the `into_shape_with_order` or `to_shape` method.

````rust
use ndarray::prelude::*;
Expand All @@ -319,7 +319,7 @@ fn main() {
let b = Array::from_iter(a.iter());
println!("b = \n{:?}\n", b);

let c = b.into_shape([6, 2]).unwrap(); // consume b and generate c with new shape
let c = b.into_shape_with_order([6, 2]).unwrap(); // consume b and generate c with new shape
println!("c = \n{:?}", c);
}
````
Expand Down Expand Up @@ -459,7 +459,7 @@ use ndarray::{Array, Axis};

fn main() {

let mut a = Array::range(0., 12., 1.).into_shape([3 ,4]).unwrap();
let mut a = Array::range(0., 12., 1.).into_shape_with_order([3 ,4]).unwrap();
println!("a = \n{}\n", a);

{
Expand Down Expand Up @@ -519,7 +519,7 @@ use ndarray::Array;

fn main() {

let mut a = Array::range(0., 4., 1.).into_shape([2 ,2]).unwrap();
let mut a = Array::range(0., 4., 1.).into_shape_with_order([2 ,2]).unwrap();
let b = a.clone();

println!("a = \n{}\n", a);
Expand Down
2 changes: 1 addition & 1 deletion benches/bench1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -915,7 +915,7 @@ const MEAN_SUM_N: usize = 127;
fn range_mat(m: Ix, n: Ix) -> Array2<f32> {
assert!(m * n != 0);
Array::linspace(0., (m * n - 1) as f32, m * n)
.into_shape((m, n))
.into_shape_with_order((m, n))
.unwrap()
}

Expand Down
4 changes: 2 additions & 2 deletions benches/construct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ fn zeros_f64(bench: &mut Bencher) {

#[bench]
fn map_regular(bench: &mut test::Bencher) {
let a = Array::linspace(0., 127., 128).into_shape((8, 16)).unwrap();
let a = Array::linspace(0., 127., 128).into_shape_with_order((8, 16)).unwrap();
bench.iter(|| a.map(|&x| 2. * x));
}

#[bench]
fn map_stride(bench: &mut test::Bencher) {
let a = Array::linspace(0., 127., 256).into_shape((8, 32)).unwrap();
let a = Array::linspace(0., 127., 256).into_shape_with_order((8, 32)).unwrap();
let av = a.slice(s![.., ..;2]);
bench.iter(|| av.map(|&x| 2. * x));
}
14 changes: 7 additions & 7 deletions benches/higher-order.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ const Y: usize = 16;

#[bench]
fn map_regular(bench: &mut Bencher) {
let a = Array::linspace(0., 127., N).into_shape((X, Y)).unwrap();
let a = Array::linspace(0., 127., N).into_shape_with_order((X, Y)).unwrap();
bench.iter(|| a.map(|&x| 2. * x));
}

Expand All @@ -28,7 +28,7 @@ pub fn double_array(mut a: ArrayViewMut2<'_, f64>) {
#[bench]
fn map_stride_double_f64(bench: &mut Bencher) {
let mut a = Array::linspace(0., 127., N * 2)
.into_shape([X, Y * 2])
.into_shape_with_order([X, Y * 2])
.unwrap();
let mut av = a.slice_mut(s![.., ..;2]);
bench.iter(|| {
Expand All @@ -39,7 +39,7 @@ fn map_stride_double_f64(bench: &mut Bencher) {
#[bench]
fn map_stride_f64(bench: &mut Bencher) {
let a = Array::linspace(0., 127., N * 2)
.into_shape([X, Y * 2])
.into_shape_with_order([X, Y * 2])
.unwrap();
let av = a.slice(s![.., ..;2]);
bench.iter(|| av.map(|&x| 2. * x));
Expand All @@ -48,7 +48,7 @@ fn map_stride_f64(bench: &mut Bencher) {
#[bench]
fn map_stride_u32(bench: &mut Bencher) {
let a = Array::linspace(0., 127., N * 2)
.into_shape([X, Y * 2])
.into_shape_with_order([X, Y * 2])
.unwrap();
let b = a.mapv(|x| x as u32);
let av = b.slice(s![.., ..;2]);
Expand All @@ -58,7 +58,7 @@ fn map_stride_u32(bench: &mut Bencher) {
#[bench]
fn fold_axis(bench: &mut Bencher) {
let a = Array::linspace(0., 127., N * 2)
.into_shape([X, Y * 2])
.into_shape_with_order([X, Y * 2])
.unwrap();
bench.iter(|| a.fold_axis(Axis(0), 0., |&acc, &elt| acc + elt));
}
Expand All @@ -69,15 +69,15 @@ const MASZ: usize = MA * MA;
#[bench]
fn map_axis_0(bench: &mut Bencher) {
let a = Array::from_iter(0..MASZ as i32)
.into_shape([MA, MA])
.into_shape_with_order([MA, MA])
.unwrap();
bench.iter(|| a.map_axis(Axis(0), black_box));
}

#[bench]
fn map_axis_1(bench: &mut Bencher) {
let a = Array::from_iter(0..MASZ as i32)
.into_shape([MA, MA])
.into_shape_with_order([MA, MA])
.unwrap();
bench.iter(|| a.map_axis(Axis(1), black_box));
}
10 changes: 5 additions & 5 deletions benches/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,29 +46,29 @@ fn iter_sum_2d_transpose(bench: &mut Bencher) {

#[bench]
fn iter_filter_sum_2d_u32(bench: &mut Bencher) {
let a = Array::linspace(0., 1., 256).into_shape((16, 16)).unwrap();
let a = Array::linspace(0., 1., 256).into_shape_with_order((16, 16)).unwrap();
let b = a.mapv(|x| (x * 100.) as u32);
bench.iter(|| b.iter().filter(|&&x| x < 75).sum::<u32>());
}

#[bench]
fn iter_filter_sum_2d_f32(bench: &mut Bencher) {
let a = Array::linspace(0., 1., 256).into_shape((16, 16)).unwrap();
let a = Array::linspace(0., 1., 256).into_shape_with_order((16, 16)).unwrap();
let b = a * 100.;
bench.iter(|| b.iter().filter(|&&x| x < 75.).sum::<f32>());
}

#[bench]
fn iter_filter_sum_2d_stride_u32(bench: &mut Bencher) {
let a = Array::linspace(0., 1., 256).into_shape((16, 16)).unwrap();
let a = Array::linspace(0., 1., 256).into_shape_with_order((16, 16)).unwrap();
let b = a.mapv(|x| (x * 100.) as u32);
let b = b.slice(s![.., ..;2]);
bench.iter(|| b.iter().filter(|&&x| x < 75).sum::<u32>());
}

#[bench]
fn iter_filter_sum_2d_stride_f32(bench: &mut Bencher) {
let a = Array::linspace(0., 1., 256).into_shape((16, 16)).unwrap();
let a = Array::linspace(0., 1., 256).into_shape_with_order((16, 16)).unwrap();
let b = a * 100.;
let b = b.slice(s![.., ..;2]);
bench.iter(|| b.iter().filter(|&&x| x < 75.).sum::<f32>());
Expand Down Expand Up @@ -321,7 +321,7 @@ fn indexed_iter_3d_dyn(bench: &mut Bencher) {
for ((i, j, k), elt) in a.indexed_iter_mut() {
*elt = (i + 100 * j + 10000 * k) as _;
}
let a = a.into_shape(&[ISZ; 3][..]).unwrap();
let a = a.into_shape_with_order(&[ISZ; 3][..]).unwrap();

bench.iter(|| {
for (i, &_elt) in a.indexed_iter() {
Expand Down
2 changes: 1 addition & 1 deletion benches/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ const Y: usize = 16;
#[bench]
fn clip(bench: &mut Bencher) {
let mut a = Array::linspace(0., 127., N * 2)
.into_shape([X, Y * 2])
.into_shape_with_order([X, Y * 2])
.unwrap();
let min = 2.;
let max = 5.;
Expand Down
2 changes: 1 addition & 1 deletion examples/axis_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ fn main() {
}
regularize(&mut b).unwrap();

let mut b = b.into_shape(a.len()).unwrap();
let mut b = b.into_shape_with_order(a.len()).unwrap();
regularize(&mut b).unwrap();

b.invert_axis(Axis(0));
Expand Down
2 changes: 1 addition & 1 deletion examples/life.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ fn parse(x: &[u8]) -> Board {
_ => None,
}));

let a = a.into_shape((N, N)).unwrap();
let a = a.into_shape_with_order((N, N)).unwrap();
map.slice_mut(s![1..-1, 1..-1]).assign(&a);
map
}
Expand Down
2 changes: 1 addition & 1 deletion examples/sort-axis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ where

#[cfg(feature = "std")]
fn main() {
let a = Array::linspace(0., 63., 64).into_shape((8, 8)).unwrap();
let a = Array::linspace(0., 63., 64).into_shape_with_order((8, 8)).unwrap();
let strings = a.map(|x| x.to_string());

let perm = a.sort_axis_by(Axis(1), |i, j| a[[i, 0]] > a[[j, 0]]);
Expand Down
4 changes: 2 additions & 2 deletions src/free_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ pub const fn aview0<A>(x: &A) -> ArrayView0<'_, A> {
/// let data = [1.0; 1024];
///
/// // Create a 2D array view from borrowed data
/// let a2d = aview1(&data).into_shape((32, 32)).unwrap();
/// let a2d = aview1(&data).into_shape_with_order((32, 32)).unwrap();
///
/// assert_eq!(a2d.sum(), 1024.0);
///
Expand Down Expand Up @@ -174,7 +174,7 @@ pub const fn aview2<A, const N: usize>(xs: &[[A; N]]) -> ArrayView2<'_, A> {
/// // Create an array view over some data, then slice it and modify it.
/// let mut data = [0; 1024];
/// {
/// let mut a = aview_mut1(&mut data).into_shape((32, 32)).unwrap();
/// let mut a = aview_mut1(&mut data).into_shape_with_order((32, 32)).unwrap();
/// a.slice_mut(s![.., ..;3]).fill(5);
/// }
/// assert_eq!(&data[..10], [5, 0, 0, 5, 0, 0, 5, 0, 0, 5]);
Expand Down
12 changes: 6 additions & 6 deletions src/impl_constructors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ where
A: Clone,
Sh: ShapeBuilder<Dim = D>,
{
let shape = shape.into_shape();
let shape = shape.into_shape_with_order();
let size = size_of_shape_checked_unwrap!(&shape.dim);
let v = vec![elem; size];
unsafe { Self::from_shape_vec_unchecked(shape, v) }
Expand Down Expand Up @@ -383,7 +383,7 @@ where
Sh: ShapeBuilder<Dim = D>,
F: FnMut() -> A,
{
let shape = shape.into_shape();
let shape = shape.into_shape_with_order();
let len = size_of_shape_checked_unwrap!(&shape.dim);
let v = to_vec_mapped(0..len, move |_| f());
unsafe { Self::from_shape_vec_unchecked(shape, v) }
Expand Down Expand Up @@ -414,7 +414,7 @@ where
Sh: ShapeBuilder<Dim = D>,
F: FnMut(D::Pattern) -> A,
{
let shape = shape.into_shape();
let shape = shape.into_shape_with_order();
let _ = size_of_shape_checked_unwrap!(&shape.dim);
if shape.is_c() {
let v = to_vec_mapped(indices(shape.dim.clone()).into_iter(), f);
Expand Down Expand Up @@ -591,7 +591,7 @@ where
Sh: ShapeBuilder<Dim = D>,
{
unsafe {
let shape = shape.into_shape();
let shape = shape.into_shape_with_order();
let size = size_of_shape_checked_unwrap!(&shape.dim);
let mut v = Vec::with_capacity(size);
v.set_len(size);
Expand Down Expand Up @@ -664,7 +664,7 @@ where
A: Copy,
Sh: ShapeBuilder<Dim = D>,
{
let shape = shape.into_shape();
let shape = shape.into_shape_with_order();
let size = size_of_shape_checked_unwrap!(&shape.dim);
let mut v = Vec::with_capacity(size);
v.set_len(size);
Expand All @@ -687,7 +687,7 @@ where
Sh: ShapeBuilder<Dim = D>,
{
unsafe {
let shape = shape.into_shape();
let shape = shape.into_shape_with_order();
let size = size_of_shape_checked_unwrap!(&shape.dim);
let mut v = Vec::with_capacity(size);
v.set_len(size);
Expand Down
Loading

0 comments on commit ba8f45c

Please sign in to comment.