Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft of potential masked array implementation. #849

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1608,6 +1608,7 @@ mod impl_raw_views;

// Copy-on-write array methods
mod impl_cow;
pub mod ma;

/// A contiguous array shape of n dimensions.
///
Expand Down
218 changes: 218 additions & 0 deletions src/ma/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
use std::cmp::PartialEq;
use std::marker::PhantomData;
use std::ops::{Add, Index};
use crate::{ArrayBase, Array1, RawData, Data, DataOwned, Dimension, NdIndex, Array, DataMut};
use crate::iter::IndexedIter;

/// Enum that represents a value that can potentially be masked.
/// We could potentially use `Option<T>` for that, but that produces
/// weird `Option<Option<T>>` return types in iterators.
/// This type can be converted to `Option<T>` using `into` method.
/// There is also a `PartialEq` implementation just to be able to
/// use it in `assert_eq!` statements.
#[derive(Clone, Copy, Debug, Eq)]
pub enum Masked<T> {
Value(T),
Empty,
}

impl<T> Masked<&T> {
fn cloned(&self) -> Masked<T>
where
T: Clone
{
match self {
Masked::Value(v) => Masked::Value((*v).clone()),
Masked::Empty => Masked::Empty,
}
}
}

impl<T> PartialEq for Masked<T>
where
T: PartialEq
{
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Masked::Value(v1), Masked::Value(v2)) => v1.eq(v2),
(Masked::Empty, Masked::Empty) => true,
_ => false,
}
}
}

impl<T> From<Masked<T>> for Option<T> {
fn from(other: Masked<T>) -> Option<T> {
match other {
Masked::Value(v) => Some(v),
Masked::Empty => None,
}
}
}

/// Every struct that can be used as a mask should implement this trait.
/// It has two generic parameters:
/// A - type of the values to be masked
/// D - dimension of the mask
/// The trait is implemented in such a way so that it could be implemented
/// by different types, not just variations of `ArrayBase`. For example,
/// we can implement a mask as a whitelist/blacklist of indices or as a
/// struct which treats some value or range of values as a mask.
pub trait Mask<A, D> {
/// Given an index of the element and a reference to it, return masked
/// version of the reference. Accepting a pair allows masking by index,
/// value or both.
fn mask_ref<'a, I: NdIndex<D>>(&self, pair: (I, &'a A)) -> Masked<&'a A>;

// Probably we will need two more methods to be able to mask by value and
// by mutable reference:

// fn mask<I: NdIndex<D>>(&self, pair: (I, A)) -> Masked<A>;
// fn mask_ref_mut<'a, I: NdIndex<D>>(&self, pair: (I, &'a mut A)) -> Masked<&'a mut A>;

fn mask_iter<'a, 'b: 'a, I>(&'b self, iter: I) -> MaskedIter<'a, A, Self, I, D>
where
I: Iterator<Item = (D::Pattern, &'a A)>,
D: Dimension,
D::Pattern: NdIndex<D>,
{
MaskedIter::new(self, iter)
}
}

/// Given two masks, generate their intersection. This may be required for any
/// binary operations with two masks.
pub trait JoinMask<A, D, M> : Mask<A, D>
where
M: Mask<A, D>
{
type Output: Mask<A, D>;

fn join(&self, other: &M) -> Self::Output;
}

pub struct MaskedIter<'a, A: 'a, M, I, D>
where
I: Iterator<Item = (D::Pattern, &'a A)>,
D: Dimension,
D::Pattern: NdIndex<D>,
M: ?Sized + Mask<A, D>
{
mask: &'a M,
iter: I,
_dim: PhantomData<D>,
}

impl<'a, A, M, I, D> MaskedIter<'a, A, M, I, D>
where
I: Iterator<Item = (D::Pattern, &'a A)>,
D: Dimension,
D::Pattern: NdIndex<D>,
M: ?Sized + Mask<A, D>
{
fn new(mask: &'a M, iter: I) -> MaskedIter<'a, A, M, I, D> {
MaskedIter { mask, iter, _dim: PhantomData }
}
}

impl<'a, A, M, I, D> Iterator for MaskedIter<'a, A, M, I, D>
where
I: Iterator<Item = (D::Pattern, &'a A)>,
D: Dimension,
D::Pattern: NdIndex<D>,
M: Mask<A, D>
{
type Item = Masked<&'a A>;

fn next(&mut self) -> Option<Self::Item> {
let nex_val = self.iter.next()?;
Some(self.mask.mask_ref(nex_val))
}
}

/// First implementation of the mask as a bool array of the same shape.
impl<A, S, D> Mask<A, D> for ArrayBase<S, D>
where
D: Dimension,
S: Data<Elem = bool>,
{
fn mask_ref<'a, I: NdIndex<D>>(&self, pair: (I, &'a A)) -> Masked<&'a A> {
if *self.index(pair.0) { Masked::Value(pair.1) } else { Masked::Empty }
}
}

impl<A, S1, S2, D> JoinMask<A, D, ArrayBase<S1, D>> for ArrayBase<S2, D>
where
D: Dimension,
S1: Data<Elem = bool>,
S2: Data<Elem = bool>,
{
type Output = Array<bool, D>;

fn join(&self, other: &ArrayBase<S1, D>) -> Self::Output {
self & other
}
}

/// Base type for masked array. `S` and `D` types are exactly the ones
/// of `ArrayBase`, `M` is a mask type.
pub struct MaskedArrayBase<S, D, M>
where
S: RawData,
M: Mask<S::Elem, D>,
{
data: ArrayBase<S, D>,
mask: M,
}

impl<S, D, M> MaskedArrayBase<S, D, M>
where
S: RawData,
D: Dimension,
M: Mask<S::Elem, D>,
{
pub fn compressed(&self) -> Array1<S::Elem>
where
S::Elem: Clone,
S: Data,
D::Pattern: NdIndex<D>,
{
self.iter()
.filter_map(|mv: Masked<&S::Elem>| mv.cloned().into())
.collect()
}

pub fn iter(&self) -> MaskedIter<'_, S::Elem, M, IndexedIter<'_, S::Elem, D>, D>
where
S: Data,
D::Pattern: NdIndex<D>,
{
self.mask.mask_iter(self.data.indexed_iter())
}
}

impl<A, S1, S2, D, M> Add<MaskedArrayBase<S2, D, M>> for MaskedArrayBase<S1, D, M>
where
A: Clone + Add<A, Output = A>,
S1: DataOwned<Elem = A> + DataMut,
S2: Data<Elem = A>,
D: Dimension,
M: Mask<A, D> + JoinMask<A, D, M>,
{
type Output = MaskedArrayBase<S1, D, <M as JoinMask<A, D, M>>::Output>;

fn add(self, rhs: MaskedArrayBase<S2, D, M>) -> Self::Output {
MaskedArrayBase {
data: self.data + rhs.data,
mask: self.mask.join(&rhs.mask),
}
}
}

pub fn array<S, D, M>(data: ArrayBase<S, D>, mask: M) -> MaskedArrayBase<S, D, M>
where
S: RawData,
M: Mask<S::Elem, D>,
{
MaskedArrayBase { data, mask }
}
44 changes: 44 additions & 0 deletions tests/ma.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use ndarray::{array};
use ndarray::ma;

#[cfg(test)]
mod test_array_mask {
use super::*;

#[test]
fn test_iter() {
let data = array![1, 2, 3, 4];
let mask = array![true, false, true, false];
let arr = ma::array(data, mask);
let actual_vec: Vec<_> = arr.iter().collect();
let expected_vec = vec![
ma::Masked::Value(&1),
ma::Masked::Empty,
ma::Masked::Value(&3),
ma::Masked::Empty,
];
assert_eq!(actual_vec, expected_vec);
}

#[test]
fn test_compressed() {
let arr = ma::array(array![1, 2, 3, 4], array![true, true, false, false]);
let res = arr.compressed();
assert_eq!(res, array![1, 2]);
}

#[test]
fn test_add() {
let arr1 = ma::array(array![1, 2, 3, 4], array![true, false, true, false]);
let arr2 = ma::array(array![4, 3, 2, 1], array![true, false, false, false]);
let res = arr1 + arr2;
let actual_vec: Vec<_> = res.iter().collect();
let expected_vec = vec![
ma::Masked::Value(&5),
ma::Masked::Empty,
ma::Masked::Empty,
ma::Masked::Empty,
];
assert_eq!(actual_vec, expected_vec);
}
}