Skip to content

Commit

Permalink
intoiter: Implement by-value iterator for owned arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
bluss committed Apr 22, 2021
1 parent 7f86eca commit d5d0892
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 1 deletion.
136 changes: 136 additions & 0 deletions src/iterators/into_iter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// Copyright 2020-2021 bluss and ndarray developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

use std::mem;
use std::ptr::NonNull;

use crate::imp_prelude::*;
use crate::OwnedRepr;

use super::Baseiter;
use crate::impl_owned_array::drop_unreachable_raw;


/// By-value iterator for an array
pub struct IntoIter<A, D>
where
D: Dimension,
{
array_data: OwnedRepr<A>,
inner: Baseiter<A, D>,
data_len: usize,
/// first memory address of an array element
array_head_ptr: NonNull<A>,
// if true, the array owns elements that are not reachable by indexing
// through all the indices of the dimension.
has_unreachable_elements: bool,
}

impl<A, D> IntoIter<A, D>
where
D: Dimension,
{
/// Create a new by-value iterator that consumes `array`
pub(crate) fn new(mut array: Array<A, D>) -> Self {
unsafe {
let array_head_ptr = array.ptr;
let ptr = array.as_mut_ptr();
let mut array_data = array.data;
let data_len = array_data.release_all_elements();
debug_assert!(data_len >= array.dim.size());
let has_unreachable_elements = array.dim.size() != data_len;
let inner = Baseiter::new(ptr, array.dim, array.strides);

IntoIter {
array_data,
inner,
data_len,
array_head_ptr,
has_unreachable_elements,
}
}
}
}

impl<A, D: Dimension> Iterator for IntoIter<A, D> {
type Item = A;

#[inline]
fn next(&mut self) -> Option<A> {
self.inner.next().map(|p| unsafe { p.read() })
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.inner.size_hint()
}
}

impl<A, D: Dimension> ExactSizeIterator for IntoIter<A, D> {
fn len(&self) -> usize { self.inner.len() }
}

impl<A, D> Drop for IntoIter<A, D>
where
D: Dimension
{
fn drop(&mut self) {
if !self.has_unreachable_elements || mem::size_of::<A>() == 0 || !mem::needs_drop::<A>() {
return;
}

// iterate til the end
while let Some(_) = self.next() { }

unsafe {
let data_ptr = self.array_data.as_ptr_mut();
let view = RawArrayViewMut::new(self.array_head_ptr, self.inner.dim.clone(),
self.inner.strides.clone());
debug_assert!(self.inner.dim.size() < self.data_len, "data_len {} and dim size {}",
self.data_len, self.inner.dim.size());
drop_unreachable_raw(view, data_ptr, self.data_len);
}
}
}

impl<A, D> IntoIterator for Array<A, D>
where
D: Dimension
{
type Item = A;
type IntoIter = IntoIter<A, D>;

fn into_iter(self) -> Self::IntoIter {
IntoIter::new(self)
}
}

impl<A, D> IntoIterator for ArcArray<A, D>
where
D: Dimension,
A: Clone,
{
type Item = A;
type IntoIter = IntoIter<A, D>;

fn into_iter(self) -> Self::IntoIter {
IntoIter::new(self.into_owned())
}
}

impl<A, D> IntoIterator for CowArray<'_, A, D>
where
D: Dimension,
A: Clone,
{
type Item = A;
type IntoIter = IntoIter<A, D>;

fn into_iter(self) -> Self::IntoIter {
IntoIter::new(self.into_owned())
}
}
3 changes: 3 additions & 0 deletions src/iterators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#[macro_use]
mod macros;
mod chunks;
mod into_iter;
pub mod iter;
mod lanes;
mod windows;
Expand All @@ -26,6 +27,7 @@ use super::{Dimension, Ix, Ixs};
pub use self::chunks::{ExactChunks, ExactChunksIter, ExactChunksIterMut, ExactChunksMut};
pub use self::lanes::{Lanes, LanesMut};
pub use self::windows::Windows;
pub use self::into_iter::IntoIter;

use std::slice::{self, Iter as SliceIter, IterMut as SliceIterMut};

Expand Down Expand Up @@ -1465,6 +1467,7 @@ unsafe impl TrustedIterator for ::std::ops::Range<usize> {}
// FIXME: These indices iter are dubious -- size needs to be checked up front.
unsafe impl<D> TrustedIterator for IndicesIter<D> where D: Dimension {}
unsafe impl<D> TrustedIterator for IndicesIterF<D> where D: Dimension {}
unsafe impl<A, D> TrustedIterator for IntoIter<A, D> where D: Dimension {}

/// Like Iterator::collect, but only for trusted length iterators
pub fn to_vec<I>(iter: I) -> Vec<I::Item>
Expand Down
101 changes: 100 additions & 1 deletion tests/iterators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
use ndarray::prelude::*;
use ndarray::{arr3, aview1, indices, s, Axis, Slice, Zip};

use itertools::{assert_equal, enumerate};
use itertools::assert_equal;
use itertools::enumerate;
use std::cell::Cell;

macro_rules! assert_panics {
($body:expr) => {
Expand Down Expand Up @@ -892,3 +894,100 @@ fn test_rfold() {
);
}
}

#[test]
fn test_into_iter() {
let a = Array1::from(vec![1, 2, 3, 4]);
let v = a.into_iter().collect::<Vec<_>>();
assert_eq!(v, [1, 2, 3, 4]);
}

#[test]
fn test_into_iter_2d() {
let a = Array1::from(vec![1, 2, 3, 4]).into_shape((2, 2)).unwrap();
let v = a.into_iter().collect::<Vec<_>>();
assert_eq!(v, [1, 2, 3, 4]);

let a = Array1::from(vec![1, 2, 3, 4]).into_shape((2, 2)).unwrap().reversed_axes();
let v = a.into_iter().collect::<Vec<_>>();
assert_eq!(v, [1, 3, 2, 4]);
}

#[test]
fn test_into_iter_sliced() {
/*
let mut a = Array1::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
.into_shape((3, 4)).unwrap();
a.slice_axis_inplace(Axis(1), Slice::from(0..1));
assert_eq!(a, arr2(&[[1], [5], [9]]));
let v = a.into_iter().collect::<Vec<_>>();
assert_eq!(v, [1, 5, 9]);
*/

let (m, n) = (4, 5);
let drops = Cell::new(0);

for i in 0..m - 1 {
for j in 0..n - 1 {
for i2 in i + 1 .. m {
for j2 in j + 1 .. n {
for invert in 0..3 {
drops.set(0);
let i = i as isize;
let j = j as isize;
let i2 = i2 as isize;
let j2 = j2 as isize;
let mut a = Array1::from_iter(0..(m * n) as i32)
.mapv(|v| DropCount::new(v, &drops))
.into_shape((m, n)).unwrap();
a.slice_collapse(s![i..i2, j..j2]);
if invert < a.ndim() {
a.invert_axis(Axis(invert));
}
//assert_eq!(a, arr2(&[[1, 2], [5, 6]]));
println!("{:?}, {:?}", i..i2, j..j2);
println!("{:?}", a);
let answer = a.iter().cloned().collect::<Vec<_>>();
let v = a.into_iter().collect::<Vec<_>>();
assert_eq!(v, answer);

assert_eq!(drops.get(), m * n - v.len());
drop(v);
assert_eq!(drops.get(), m * n);
}
}
}
}
}
}

/// Helper struct that counts its drops Asserts that it's not dropped twice. Also global number of
/// drops is counted in the cell.
///
/// Compares equal by its "represented value".
#[derive(Clone, Debug)]
struct DropCount<'a> {
value: i32,
my_drops: usize,
drops: &'a Cell<usize>
}

impl PartialEq for DropCount<'_> {
fn eq(&self, other: &Self) -> bool {
self.value == other.value
}
}

impl<'a> DropCount<'a> {
fn new(value: i32, drops: &'a Cell<usize>) -> Self {
DropCount { value, my_drops: 0, drops }
}
}

impl Drop for DropCount<'_> {
fn drop(&mut self) {
assert_eq!(self.my_drops, 0);
self.my_drops += 1;
self.drops.set(self.drops.get() + 1);
}
}

0 comments on commit d5d0892

Please sign in to comment.