Skip to content

Commit

Permalink
Add const-generic IndexedParallelIterator::arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
cuviper committed Dec 13, 2023
1 parent d1b18e6 commit 5e16c5e
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 0 deletions.
221 changes: 221 additions & 0 deletions src/iter/arrays.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
use super::plumbing::*;
use super::*;

/// `Arrays` is a parallel iterator adaptor that groups the items of an
/// underlying indexed iterator into fixed-size arrays of `N` elements.
///
/// This struct is created by the [`arrays()`] method on [`IndexedParallelIterator`].
///
/// [`arrays()`]: trait.IndexedParallelIterator.html#method.arrays
/// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
#[derive(Debug, Clone)]
pub struct Arrays<I: IndexedParallelIterator, const N: usize> {
    /// The wrapped base iterator whose items are grouped.
    iter: I,
}

impl<I: IndexedParallelIterator, const N: usize> Arrays<I, N> {
    /// Wraps `iter` in a new `Arrays` adaptor.
    pub(super) fn new(iter: I) -> Self {
        Self { iter }
    }
}

impl<I: IndexedParallelIterator, const N: usize> ParallelIterator for Arrays<I, N> {
    /// Each item is an array of `N` items taken from the base iterator.
    type Item = [I::Item; N];

    /// Hands the consumer to `bridge`, the same path used by the indexed `drive`.
    fn drive_unindexed<C: Consumer<Self::Item>>(self, consumer: C) -> C::Result {
        bridge(self, consumer)
    }

    /// The length is always known: it is the indexed `len()` of this adaptor.
    fn opt_len(&self) -> Option<usize> {
        let exact = self.len();
        Some(exact)
    }
}

impl<I: IndexedParallelIterator, const N: usize> IndexedParallelIterator for Arrays<I, N> {
    /// Drives the consumer through `bridge`.
    fn drive<C: Consumer<Self::Item>>(self, consumer: C) -> C::Result {
        bridge(self, consumer)
    }

    /// Number of complete arrays; a trailing partial group is not counted.
    fn len(&self) -> usize {
        let base_len = self.iter.len();
        base_len / N
    }

    /// Asks the base iterator for its producer and wraps it in an
    /// `ArrayProducer` before passing it on to `callback`.
    fn with_producer<CB: ProducerCallback<Self::Item>>(self, callback: CB) -> CB::Output {
        /// Adapter that receives the base producer and forwards an
        /// `ArrayProducer` to the wrapped callback.
        struct Callback<CB, const N: usize> {
            len: usize,
            callback: CB,
        }

        impl<T, CB, const N: usize> ProducerCallback<T> for Callback<CB, N>
        where
            CB: ProducerCallback<[T; N]>,
        {
            type Output = CB::Output;

            fn callback<P: Producer<Item = T>>(self, base: P) -> CB::Output {
                let Callback { len, callback } = self;
                callback.callback(ArrayProducer { len, base })
            }
        }

        // Capture the base length up front: `with_producer` consumes `self.iter`.
        let len = self.iter.len();
        self.iter.with_producer(Callback { len, callback })
    }
}

/// Producer that yields `[P::Item; N]` arrays built from a base producer.
struct ArrayProducer<P: Producer, const N: usize> {
    /// Count of base items covered by `base`, tracked alongside the producer.
    len: usize,
    /// The underlying producer of individual items.
    base: P,
}

impl<P, const N: usize> Producer for ArrayProducer<P, N>
where
    P: Producer,
{
    type Item = [P::Item; N];
    type IntoIter = ArraySeq<P, N>;

    /// Converts this producer into a sequential iterator over complete arrays.
    ///
    /// Any trailing items that do not fill a whole array are cut off.
    fn into_iter(self) -> Self::IntoIter {
        // TODO: we're ignoring any remainder -- should we no-op consume it?
        let remainder = self.len % N;
        let len = self.len - remainder;
        // With no complete array there is nothing to iterate, so the base is
        // never split; `None` makes `ArraySeq` immediately exhausted.
        let inner = (len > 0).then(|| self.base.split_at(len).0);
        ArraySeq { len, inner }
    }

    /// Splits at array `index`, i.e. at base item `index * N`.
    fn split_at(self, index: usize) -> (Self, Self) {
        let elem_index = index * N;
        let (left, right) = self.base.split_at(elem_index);
        (
            ArrayProducer {
                len: elem_index,
                base: left,
            },
            ArrayProducer {
                len: self.len - elem_index,
                base: right,
            },
        )
    }

    /// Minimum number of arrays per sequential piece.
    ///
    /// Rounded *up* so that the arrays in a piece cover at least as many base
    /// items as the base producer's own minimum; rounding down could split
    /// into pieces smaller than the base asked for.
    fn min_len(&self) -> usize {
        let base_min = self.base.min_len();
        // Overflow-free ceiling division.
        base_min / N + usize::from(base_min % N != 0)
    }

    /// Maximum number of arrays per sequential piece.
    ///
    /// Rounded down so a piece never covers more base items than the base
    /// producer's maximum.
    fn max_len(&self) -> usize {
        self.base.max_len() / N
    }
}

/// Sequential iterator created by `ArrayProducer::into_iter`, yielding one
/// `[P::Item; N]` array at a time from either end.
struct ArraySeq<P, const N: usize> {
// Number of base items still to be consumed; kept a multiple of `N`
// (checked by the debug assertions in `next` / `next_back`).
len: usize,
// Producer for the remaining items; `None` once exhausted, or from the
// start when there was not even one complete array.
inner: Option<P>,
}

impl<P: Producer, const N: usize> Iterator for ArraySeq<P, N> {
    type Item = [P::Item; N];

    /// Yields the next array from the front, splitting `N` items off the
    /// remaining producer.
    fn next(&mut self) -> Option<Self::Item> {
        let rest = self.inner.take()?;
        debug_assert!(self.len > 0 && self.len % N == 0);

        let head = if self.len > N {
            // More than one array left: peel off the first `N` items and
            // keep the tail for later calls.
            let (head, tail) = rest.split_at(N);
            self.inner = Some(tail);
            self.len -= N;
            head
        } else {
            // Exactly one array left: consume the producer as-is.
            self.len = 0;
            rest
        };

        Some(collect_array(head.into_iter()))
    }

    /// Exact: the number of arrays still to come.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.len();
        (remaining, Some(remaining))
    }
}

impl<P: Producer, const N: usize> ExactSizeIterator for ArraySeq<P, N> {
    /// Number of arrays still to be yielded (remaining base items over `N`).
    #[inline]
    fn len(&self) -> usize {
        let items_left = self.len;
        items_left / N
    }
}

impl<P: Producer, const N: usize> DoubleEndedIterator for ArraySeq<P, N> {
    /// Yields the next array from the back, splitting the last `N` items off
    /// the remaining producer.
    fn next_back(&mut self) -> Option<Self::Item> {
        let rest = self.inner.take()?;
        debug_assert!(self.len > 0 && self.len % N == 0);

        let tail = if self.len > N {
            // More than one array left: keep everything before the last `N`
            // items for later calls.
            let (front, tail) = rest.split_at(self.len - N);
            self.inner = Some(front);
            self.len -= N;
            tail
        } else {
            // Exactly one array left: consume the producer as-is.
            self.len = 0;
            rest
        };

        Some(collect_array(tail.into_iter()))
    }
}

/// Collects exactly `N` items from `iter` into an array.
///
/// The iterator is expected to yield exactly `N` items. This is checked with
/// debug assertions, and running short panics with "should have N items".
///
/// If a panic occurs after some items were already collected (from the
/// iterator itself or from one of the checks), those items are dropped
/// instead of being leaked.
fn collect_array<T, const N: usize>(mut iter: impl ExactSizeIterator<Item = T>) -> [T; N] {
    // TODO(MSRV-1.55): consider `[(); N].map(...)`
    // TODO(MSRV-1.63): consider `std::array::from_fn`

    use std::mem::{self, MaybeUninit};
    use std::ptr;

    /// Drops the initialized prefix of `array` if we unwind before finishing.
    struct Guard<'a, T, const N: usize> {
        array: &'a mut [MaybeUninit<T>; N],
        initialized: usize,
    }

    impl<'a, T, const N: usize> Drop for Guard<'a, T, N> {
        fn drop(&mut self) {
            let filled: *mut [MaybeUninit<T>] = &mut self.array[..self.initialized];
            // SAFETY: exactly the first `initialized` slots have been written,
            // and `MaybeUninit<T>` has the same layout as `T`.
            unsafe { ptr::drop_in_place(filled as *mut [T]) }
        }
    }

    // TODO(MSRV): use `MaybeUninit::uninit_array` when/if it's stabilized.
    // SAFETY: We can assume "init" when moving uninit wrappers inward.
    let mut array: [MaybeUninit<T>; N] =
        unsafe { MaybeUninit::<[MaybeUninit<T>; N]>::uninit().assume_init() };

    debug_assert_eq!(iter.len(), N);

    let mut guard = Guard {
        array: &mut array,
        initialized: 0,
    };
    for i in 0..N {
        // `next()` or `expect` may panic; the guard then cleans up slots `0..i`.
        let item = iter.next().expect("should have N items");
        guard.array[i] = MaybeUninit::new(item);
        guard.initialized = i + 1;
    }
    debug_assert!(iter.next().is_none());

    // Every slot is initialized: disarm the guard so ownership of the items
    // passes to the returned array instead of being dropped here.
    mem::forget(guard);

    // TODO(MSRV): use `MaybeUninit::array_assume_init` when/if it's stabilized.
    // SAFETY: We've initialized all N items in the array, so we can cast and "move" it.
    unsafe { (&array as *const [MaybeUninit<T>; N] as *const [T; N]).read() }
}
28 changes: 28 additions & 0 deletions src/iter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ mod test;
// e.g. `find::find()`, are always used **prefixed**, so that they
// can be readily distinguished.

mod arrays;
mod chain;
mod chunks;
mod cloned;
Expand Down Expand Up @@ -159,6 +160,7 @@ mod zip;
mod zip_eq;

pub use self::{
arrays::Arrays,
chain::Chain,
chunks::Chunks,
cloned::Cloned,
Expand Down Expand Up @@ -2544,6 +2546,32 @@ pub trait IndexedParallelIterator: ParallelIterator {
InterleaveShortest::new(self, other.into_par_iter())
}

/// Splits an iterator up into fixed-size arrays.
///
/// Returns an iterator that returns arrays with the given number of elements.
/// If the number of elements in the iterator is not divisible by `N`,
/// the remaining items are ignored.
///
/// See also [`par_array_chunks()`] and [`par_array_chunks_mut()`] for similar
/// behavior on slices, although they yield array references instead.
///
/// [`par_array_chunks()`]: ../slice/trait.ParallelSlice.html#method.par_array_chunks
/// [`par_array_chunks_mut()`]: ../slice/trait.ParallelSliceMut.html#method.par_array_chunks_mut
///
/// # Panics
///
/// Panics if `N` is zero.
///
/// # Examples
///
/// ```
/// use rayon::prelude::*;
/// let a = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
/// let r: Vec<[i32; 3]> = a.into_par_iter().arrays().collect();
/// assert_eq!(r, vec![[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
/// ```
#[track_caller]
fn arrays<const N: usize>(self) -> Arrays<Self, N> {
// Reject zero-length arrays up front; `#[track_caller]` makes the panic
// location point at the caller of `arrays`.
assert!(N != 0, "array length must not be zero");
Arrays::new(self)
}

/// Splits an iterator up into fixed-size chunks.
///
/// Returns an iterator that returns `Vec`s of the given number of elements.
Expand Down
1 change: 1 addition & 0 deletions tests/clones.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ fn clone_adaptors() {
check(v.par_iter().interleave_shortest(&v));
check(v.par_iter().intersperse(&None));
check(v.par_iter().chunks(3));
check(v.par_iter().arrays::<3>());
check(v.par_iter().map(|x| x));
check(v.par_iter().map_with(0, |_, x| x));
check(v.par_iter().map_init(|| 0, |_, x| x));
Expand Down
1 change: 1 addition & 0 deletions tests/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ fn debug_adaptors() {
check(v.par_iter().interleave_shortest(&v));
check(v.par_iter().intersperse(&-1));
check(v.par_iter().chunks(3));
check(v.par_iter().arrays::<3>());
check(v.par_iter().map(|x| x));
check(v.par_iter().map_with(0, |_, x| x));
check(v.par_iter().map_init(|| 0, |_, x| x));
Expand Down
23 changes: 23 additions & 0 deletions tests/producer_split_at.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,29 @@ fn chunks() {
check(&v, || s.par_iter().cloned().chunks(2));
}

#[test]
fn arrays() {
    use std::convert::TryInto;

    // Compares `arrays::<N>()` against the sequential `chunks_exact(N)` baseline.
    fn check_len<const N: usize>(s: &[i32]) {
        let expected: Vec<[_; N]> = s
            .chunks_exact(N)
            .map(|chunk| chunk.try_into().unwrap())
            .collect();
        check(&expected, || s.par_iter().copied().arrays::<N>());
    }

    // A const generic argument cannot come from a runtime loop, so expand one
    // `check_len` call per listed length.
    macro_rules! check_lens {
        ($s:expr; $($n:literal),+ $(,)?) => {
            $( check_len::<{ $n }>($s); )+
        };
    }

    let s: Vec<_> = (0..10).collect();
    // Lengths above 10 exercise the case where no complete array exists.
    check_lens!(&s; 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12);
}

#[test]
fn map() {
let v: Vec<_> = (0..10).collect();
Expand Down

0 comments on commit 5e16c5e

Please sign in to comment.