diff --git a/src/iterators/into_iter.rs b/src/iterators/into_iter.rs new file mode 100644 index 000000000..cfa48299a --- /dev/null +++ b/src/iterators/into_iter.rs @@ -0,0 +1,136 @@ +// Copyright 2020-2021 bluss and ndarray developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +use std::mem; +use std::ptr::NonNull; + +use crate::imp_prelude::*; +use crate::OwnedRepr; + +use super::Baseiter; +use crate::impl_owned_array::drop_unreachable_raw; + + +/// By-value iterator for an array +pub struct IntoIter +where + D: Dimension, +{ + array_data: OwnedRepr, + inner: Baseiter, + data_len: usize, + /// first memory address of an array element + array_head_ptr: NonNull, + // if true, the array owns elements that are not reachable by indexing + // through all the indices of the dimension. + has_unreachable_elements: bool, +} + +impl IntoIter +where + D: Dimension, +{ + /// Create a new by-value iterator that consumes `array` + pub(crate) fn new(mut array: Array) -> Self { + unsafe { + let array_head_ptr = array.ptr; + let ptr = array.as_mut_ptr(); + let mut array_data = array.data; + let data_len = array_data.release_all_elements(); + debug_assert!(data_len >= array.dim.size()); + let has_unreachable_elements = array.dim.size() != data_len; + let inner = Baseiter::new(ptr, array.dim, array.strides); + + IntoIter { + array_data, + inner, + data_len, + array_head_ptr, + has_unreachable_elements, + } + } + } +} + +impl Iterator for IntoIter { + type Item = A; + + #[inline] + fn next(&mut self) -> Option { + self.inner.next().map(|p| unsafe { p.read() }) + } + + fn size_hint(&self) -> (usize, Option) { + self.inner.size_hint() + } +} + +impl ExactSizeIterator for IntoIter { + fn len(&self) -> usize { self.inner.len() } +} + +impl Drop for IntoIter +where + D: Dimension +{ + fn drop(&mut self) { + if !self.has_unreachable_elements || mem::size_of::() == 0 || !mem::needs_drop::() { + return; + } + + // iterate til the end + while let Some(_) = self.next() { } + + unsafe { + let data_ptr = self.array_data.as_ptr_mut(); + let view = RawArrayViewMut::new(self.array_head_ptr, self.inner.dim.clone(), + self.inner.strides.clone()); + debug_assert!(self.inner.dim.size() < self.data_len, "data_len {} and dim size {}", + self.data_len, self.inner.dim.size()); + drop_unreachable_raw(view, data_ptr, self.data_len); + } + } +} + +impl IntoIterator for Array +where + D: Dimension +{ + type Item = A; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self) + } +} + +impl IntoIterator for ArcArray +where + D: Dimension, + A: Clone, +{ + type Item = A; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self.into_owned()) + } +} + +impl IntoIterator for CowArray<'_, A, D> +where + D: Dimension, + A: Clone, +{ + type Item = A; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + IntoIter::new(self.into_owned()) + } +} diff --git a/src/iterators/mod.rs b/src/iterators/mod.rs index 595f0897d..bb618e5be 100644 --- a/src/iterators/mod.rs +++ b/src/iterators/mod.rs @@ -9,6 +9,7 @@ #[macro_use] mod macros; mod chunks; +mod into_iter; pub mod iter; mod lanes; mod windows; @@ -26,6 +27,7 @@ use super::{Dimension, Ix, Ixs}; pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut}; pub use self::lanes::{Lanes, LanesMut}; pub use self::windows::Windows; +pub use self::into_iter::IntoIter; use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut}; @@ -1465,6 +1467,7 @@ unsafe impl TrustedIterator for ::std::ops::Range {} // FIXME: These indices iter are dubious -- size needs to be checked up front. unsafe impl TrustedIterator for IndicesIter where D: Dimension {} unsafe impl TrustedIterator for IndicesIterF where D: Dimension {} +unsafe impl TrustedIterator for IntoIter where D: Dimension {} /// Like Iterator::collect, but only for trusted length iterators pub fn to_vec(iter: I) -> Vec diff --git a/tests/iterators.rs b/tests/iterators.rs index 4e4bbc666..a7c915389 100644 --- a/tests/iterators.rs +++ b/tests/iterators.rs @@ -8,7 +8,9 @@ use ndarray::prelude::*; use ndarray::{arr3, aview1, indices, s, Axis, Slice, Zip}; -use itertools::{assert_equal, enumerate}; +use itertools::assert_equal; +use itertools::enumerate; +use std::cell::Cell; macro_rules! assert_panics { ($body:expr) => { @@ -892,3 +894,100 @@ fn test_rfold() { ); } } + +#[test] +fn test_into_iter() { + let a = Array1::from(vec![1, 2, 3, 4]); + let v = a.into_iter().collect::>(); + assert_eq!(v, [1, 2, 3, 4]); +} + +#[test] +fn test_into_iter_2d() { + let a = Array1::from(vec![1, 2, 3, 4]).into_shape((2, 2)).unwrap(); + let v = a.into_iter().collect::>(); + assert_eq!(v, [1, 2, 3, 4]); + + let a = Array1::from(vec![1, 2, 3, 4]).into_shape((2, 2)).unwrap().reversed_axes(); + let v = a.into_iter().collect::>(); + assert_eq!(v, [1, 3, 2, 4]); +} + +#[test] +fn test_into_iter_sliced() { + /* + let mut a = Array1::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + .into_shape((3, 4)).unwrap(); + a.slice_axis_inplace(Axis(1), Slice::from(0..1)); + assert_eq!(a, arr2(&[[1], [5], [9]])); + let v = a.into_iter().collect::>(); + assert_eq!(v, [1, 5, 9]); + */ + + let (m, n) = (4, 5); + let drops = Cell::new(0); + + for i in 0..m - 1 { + for j in 0..n - 1 { + for i2 in i + 1 .. m { + for j2 in j + 1 .. n { + for invert in 0..3 { + drops.set(0); + let i = i as isize; + let j = j as isize; + let i2 = i2 as isize; + let j2 = j2 as isize; + let mut a = Array1::from_iter(0..(m * n) as i32) + .mapv(|v| DropCount::new(v, &drops)) + .into_shape((m, n)).unwrap(); + a.slice_collapse(s![i..i2, j..j2]); + if invert < a.ndim() { + a.invert_axis(Axis(invert)); + } + //assert_eq!(a, arr2(&[[1, 2], [5, 6]])); + println!("{:?}, {:?}", i..i2, j..j2); + println!("{:?}", a); + let answer = a.iter().cloned().collect::>(); + let v = a.into_iter().collect::>(); + assert_eq!(v, answer); + + assert_eq!(drops.get(), m * n - v.len()); + drop(v); + assert_eq!(drops.get(), m * n); + } + } + } + } + } +} + +/// Helper struct that counts its drops Asserts that it's not dropped twice. Also global number of +/// drops is counted in the cell. +/// +/// Compares equal by its "represented value". +#[derive(Clone, Debug)] +struct DropCount<'a> { + value: i32, + my_drops: usize, + drops: &'a Cell +} + +impl PartialEq for DropCount<'_> { + fn eq(&self, other: &Self) -> bool { + self.value == other.value + } +} + +impl<'a> DropCount<'a> { + fn new(value: i32, drops: &'a Cell) -> Self { + DropCount { value, my_drops: 0, drops } + } +} + +impl Drop for DropCount<'_> { + fn drop(&mut self) { + assert_eq!(self.my_drops, 0); + self.my_drops += 1; + self.drops.set(self.drops.get() + 1); + } +}