
Features/179 repeat #674

Merged
merged 53 commits
Oct 26, 2020
Commits
0ced160
First approach to implementation
lenablind Sep 1, 2020
f229825
Additional type restrictions
lenablind Sep 1, 2020
285fa1c
Restrictions for repeats
lenablind Sep 2, 2020
b79bf5e
Implementation of undistributed function
lenablind Sep 2, 2020
4c7f425
Check for repeats consisting only of integers as DNDarray
lenablind Sep 2, 2020
f7f23dd
First test cases for undistributed implementation
lenablind Sep 2, 2020
17cd911
Checks for 'repeats.dtype' DNDarray or np.ndarray
lenablind Sep 3, 2020
70170cd
Additional tests and first approach to distribution
lenablind Sep 3, 2020
2d8bda2
Query for dtype of np.ndarray and ht.DNDarray
lenablind Sep 3, 2020
dda37ca
Broken for distributed case. Adaption of split syntax
lenablind Sep 3, 2020
aa7a412
Removal of DNDarray as a possible dtype for 'repeats'
lenablind Sep 4, 2020
fdff989
Adaption of tests - Replaced assert_array_equal with all equivalent
lenablind Sep 4, 2020
8ccefbb
Broken - Adaption of split algorithm for repeats
lenablind Sep 7, 2020
b9713b6
Usage of resplit_
lenablind Sep 7, 2020
adbe489
Merge branch 'master' into features/179-repeat
lenablind Sep 8, 2020
0713d47
Use of ht.empty
lenablind Sep 8, 2020
328afc9
Restructured sanitation of a, replaced ht.flatten with torch.flatten …
lenablind Sep 8, 2020
c7aed40
Moved (test) functions into manipulations
lenablind Sep 9, 2020
e25e218
is_split adaption for repeats
lenablind Sep 9, 2020
5e1aef7
Additional tests for axis != None
lenablind Sep 9, 2020
d46eb65
Improvement of lshape handling
lenablind Sep 9, 2020
a80eab4
Algorithm simplification
lenablind Sep 9, 2020
4c57e09
DNDarray as valid dtype for repeats
lenablind Sep 10, 2020
6828035
Repeats distributed, axis = None
lenablind Sep 10, 2020
81bcb60
a distributed, repeats undistributed
lenablind Sep 11, 2020
210ea27
Code restructured
lenablind Sep 11, 2020
c4f1bb9
Code restructured
lenablind Sep 11, 2020
919f005
Broken if both repeats & a are distributed
lenablind Sep 11, 2020
bbb3552
Repeats distributed, axis != None, Debugging
lenablind Sep 14, 2020
ebb12c3
All working testcases, debugging
lenablind Sep 14, 2020
99830ce
Restructured code
lenablind Sep 14, 2020
1eed38c
Quickfix tricky test case
lenablind Sep 14, 2020
fe3e423
isDeleted quickfix, redefined is_split
lenablind Sep 15, 2020
a92468f
resplit_ in any case after reshape
lenablind Sep 16, 2020
3ad2a3a
New algorithm approach (if axis is None)
lenablind Sep 16, 2020
57d455e
Moved all sanitation fragments to the top
lenablind Sep 16, 2020
c3697de
Removed print statement s out of manipulations.py
lenablind Sep 16, 2020
8e178b6
Moved empty case (no data) to bottom
lenablind Sep 17, 2020
d7681d3
Broadcast via 1 element list
lenablind Sep 17, 2020
216859d
Slicing approach - distributed case, axis=None
lenablind Sep 17, 2020
c428f61
Replaced repeats.resplit_(0) with slicing strategy
lenablind Sep 17, 2020
2d03365
Combine both broadcasts, change back to resplit_(0) for less error su…
lenablind Sep 18, 2020
8cd457d
More precise warnings, handling of globally empty input
lenablind Sep 18, 2020
fd75a78
reshape & flatten repeats if axis is None
lenablind Sep 18, 2020
ff956c7
Additional tests for different dtypes of a
lenablind Sep 18, 2020
daf0cee
Additional tests
lenablind Sep 18, 2020
ac9e885
Adapted docstring
lenablind Sep 18, 2020
accb241
Update of CHANGELOG
lenablind Sep 18, 2020
f77b379
Merge branch 'master' into features/179-repeat
coquelin77 Sep 21, 2020
b75ce4c
Resplit out of place, warnings instead of (adapted) print
lenablind Sep 21, 2020
40e62db
Replaced array_equal with all() equivalent
lenablind Sep 22, 2020
7c8fafc
Solved merge conflict
lenablind Oct 26, 2020
229585a
Adaption to requested changes
lenablind Oct 26, 2020
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -52,6 +52,7 @@
- [#664](https://github.com/helmholtz-analytics/heat/pull/664) New feature / enhancement: distributed `random.random_sample`, `random.random`, `random.sample`, `random.ranf`, `random.random_integer`
- [#666](https://github.com/helmholtz-analytics/heat/pull/666) New feature: distributed prepend/append for `diff()`.
- [#667](https://github.com/helmholtz-analytics/heat/pull/667) Enhancement `reshape`: rename axis parameter
- [#674](https://github.com/helmholtz-analytics/heat/pull/674) New feature: `repeat`
- [#670](https://github.com/helmholtz-analytics/heat/pull/670) New Feature: distributed `bincount()`
- [#672](https://github.com/helmholtz-analytics/heat/pull/672) Bug / Enhancement: Remove `MPIRequest.wait()`, rewrite calls with capital letters. lower case `wait()` now falls back to the `mpi4py` function

258 changes: 257 additions & 1 deletion heat/core/manipulations.py
@@ -13,7 +13,6 @@
from . import types
from . import _operations


__all__ = [
"column_stack",
"concatenate",
@@ -26,6 +25,7 @@
"flipud",
"hstack",
"pad",
"repeat",
"reshape",
"resplit",
"rot90",
@@ -1238,6 +1238,262 @@ def pad(array, pad_width, mode="constant", constant_values=0):
return padded_tensor


def repeat(a, repeats, axis=None):
"""
Creates a new DNDarray by repeating elements of array a.

Parameters
----------
a : array_like (i.e. int, float, or tuple/list/np.ndarray/ht.DNDarray of ints/floats)
Array containing the elements to be repeated.
repeats : int, or 1-dimensional DNDarray/np.ndarray/list/tuple of ints
The number of repetitions for each element. An int or array_like of a single element indicates a
broadcast: the given value is broadcast to fit the shape of the given axis. Otherwise, its length
must match the size of `a` along the specified axis, i.e. the number of repetitions has to be
specified for each element in the corresponding dimension (or for every element of the flattened
array if axis is None).
axis : int, optional
The axis along which to repeat values. By default, the flattened input array is used and a flat
output array is returned.

Returns
-------
repeated_array : DNDarray
Output DNDarray which has the same shape as `a`, except along the given axis.
If axis is None, repeated_array will be a flattened DNDarray.

Examples
--------
>>> ht.repeat(3, 4)
DNDarray([3, 3, 3, 3])

>>> x = ht.array([[1,2],[3,4]])
>>> ht.repeat(x, 2)
DNDarray([1, 1, 2, 2, 3, 3, 4, 4])

>>> x = ht.array([[1,2],[3,4]])
>>> ht.repeat(x, [0, 1, 2, 0])
DNDarray([2, 3, 3])

>>> ht.repeat(x, [1,2], axis=0)
DNDarray([[1, 2],
[3, 4],
[3, 4]])
"""

# sanitation `a`
if not isinstance(a, dndarray.DNDarray):
if isinstance(a, (int, float)):
a = factories.array([a])
elif isinstance(a, (tuple, list, np.ndarray)):
a = factories.array(a)
else:
raise TypeError(
"`a` must be a ht.DNDarray, np.ndarray, list, tuple, integer, or float, currently: {}".format(
type(a)
)
)

# sanitation `axis`
if axis is not None and not isinstance(axis, int):
raise TypeError("`axis` must be an integer or None, currently: {}".format(type(axis)))

if axis is not None and (axis >= len(a.shape) or axis < 0):
raise ValueError(
"Invalid input for `axis`. Value has to be either None or between 0 and {}, not {}.".format(
len(a.shape) - 1, axis
)
)

# sanitation `repeats`
if not isinstance(repeats, (int, list, tuple, np.ndarray, dndarray.DNDarray)):
raise TypeError(
"`repeats` must be an integer, list, tuple, np.ndarray or ht.DNDarray of integers, currently: {}".format(
type(repeats)
)
)

# no broadcast implied
if not isinstance(repeats, int):
# make sure everything inside `repeats` is int
if isinstance(repeats, dndarray.DNDarray):
if repeats.dtype == types.int64:
pass
elif types.can_cast(repeats.dtype, types.int64):
repeats = factories.array(
repeats,
dtype=types.int64,
is_split=repeats.split,
device=repeats.device,
comm=repeats.comm,
)
else:
raise TypeError(
"Invalid dtype for ht.DNDarray `repeats`. Has to be integer,"
" but was {}".format(repeats.dtype)
)
elif isinstance(repeats, np.ndarray):
if not types.can_cast(repeats.dtype.type, types.int64):
raise TypeError(
"Invalid dtype for np.ndarray `repeats`. Has to be integer,"
" but was {}".format(repeats.dtype.type)
)
repeats = factories.array(
repeats, dtype=types.int64, is_split=None, device=a.device, comm=a.comm
)
# invalid list/tuple
elif not all(isinstance(r, int) for r in repeats):
raise TypeError(
"Invalid type within `repeats`. All components of `repeats` must be integers."
)
# valid list/tuple
else:
repeats = factories.array(
repeats, dtype=types.int64, is_split=None, device=a.device, comm=a.comm
)

# check `repeats` is not empty
if repeats.gnumel == 0:
raise ValueError("Invalid input for `repeats`. `repeats` must contain data.")

# check `repeats` is 1-dimensional
if len(repeats.shape) != 1:
raise ValueError(
"Invalid input for `repeats`. `repeats` must be a 1d-object or integer, but "
"was {}-dimensional.".format(len(repeats.shape))
)

# start of algorithm

if 0 in a.gshape:
return a

# Broadcast (via int or 1-element DNDarray)
if isinstance(repeats, int) or repeats.gnumel == 1:
if axis is None and a.split is not None and a.split != 0:
warnings.warn(
"If axis is None, `a` has to be split along axis 0 (not {}) if distributed.\n`a` will be "
"copied with new split axis 0.".format(a.split)
)
a = resplit(a, 0)
if isinstance(repeats, int):
repeated_array_torch = torch.repeat_interleave(a._DNDarray__array, repeats, axis)
else:
if repeats.split is not None:
warnings.warn(
"For broadcast via array_like repeats, `repeats` must not be "
"distributed (along axis {}).\n`repeats` will be "
"copied with new split axis None.".format(repeats.split)
)
repeats = resplit(repeats, None)
repeated_array_torch = torch.repeat_interleave(
a._DNDarray__array, repeats._DNDarray__array, axis
)
# No broadcast
else:
# check if the data chunks of `repeats` and/or `a` have to be (re)distributed before call of torch function.

# UNDISTRIBUTED CASE (a not distributed)
if a.split is None:
if repeats.split is not None:
warnings.warn(
"If `a` is undistributed, `repeats` also has to be undistributed (not split along axis {}).\n`repeats` will be copied "
"with new split axis None.".format(repeats.split)
)
repeats = resplit(repeats, None)

# Check correct input
if axis is None:
# check matching shapes (repetition defined for every element)
if a.gnumel != repeats.gnumel:
raise ValueError(
"Invalid input. Sizes of flattened `a` ({}) and `repeats` ({}) are not same. "
"Please revise your definition specifying repetitions for all elements "
"of the DNDarray `a` or replace repeats with a single"
" scalar.".format(a.gnumel, repeats.gnumel)
)
# axis is not None
elif a.lshape[axis] != repeats.lnumel:
raise ValueError(
"Invalid input. Amount of elements of `repeats` ({}) and of `a` in the specified axis ({}) "
"are not the same. Please revise your definition specifying repetitions for all elements "
"of the DNDarray `a` or replace `repeats` with a single scalar".format(
repeats.lnumel, a.lshape[axis]
)
)
# DISTRIBUTED CASE (a distributed)
else:
if axis is None:
if a.gnumel != repeats.gnumel:
raise ValueError(
"Invalid input. Sizes of flattened `a` ({}) and `repeats` ({}) are not same. "
"Please revise your definition specifying repetitions for all elements "
"of the DNDarray `a` or replace `repeats` with a single"
" scalar.".format(a.gnumel, repeats.gnumel)
)

if a.split != 0:
warnings.warn(
"If `axis` is None, `a` has to be split along axis 0 (not {}) if distributed.\n`a` will be copied"
" with new split axis 0.".format(a.split)
)
a = resplit(a, 0)

repeats = repeats.reshape(a.gshape)
if repeats.split != 0:
warnings.warn(
"If `axis` is None, `repeats` has to be split along axis 0 (not {}) if distributed.\n`repeats` will be copied"
" with new split axis 0.".format(repeats.split)
)
repeats = resplit(repeats, 0)
flatten_repeats_t = torch.flatten(repeats._DNDarray__array)
repeats = factories.array(
flatten_repeats_t,
is_split=repeats.split,
device=repeats.device,
comm=repeats.comm,
)

# axis is not None
else:
if a.split == axis:
if repeats.split != 0:
warnings.warn(
"If `axis` equals `a.split`, `repeats` has to be split along axis 0 (not {}) if distributed.\n"
"`repeats` will be copied with new split axis 0".format(repeats.split)
)
repeats = resplit(repeats, 0)

# a.split != axis
else:
if repeats.split is not None:
warnings.warn(
"If `axis` != `a.split`, `repeast` must not be distributed (along axis {}).\n`repeats` will be copied with new"
" split axis None.".format(repeats.split)
)
repeats = resplit(repeats, None)

if a.lshape[axis] != repeats.lnumel:
raise ValueError(
"Invalid input. Amount of elements of `repeats` ({}) and of `a` in the specified axis ({}) "
"are not the same. Please revise your definition specifying repetitions for all elements "
"of the DNDarray `a` or replace `repeats` with a single scalar".format(
repeats.lnumel, a.lshape[axis]
)
)

repeated_array_torch = torch.repeat_interleave(
a._DNDarray__array, repeats._DNDarray__array, axis
)

repeated_array = factories.array(
repeated_array_torch, dtype=a.dtype, is_split=a.split, device=a.device, comm=a.comm
)
repeated_array.balance_()

return repeated_array
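
For reference, the semantics of `repeat` mirror NumPy's `np.repeat` (the docstring examples above produce the same values). A minimal sketch of those semantics, using NumPy as a stand-in since it runs without a Heat installation:

```python
import numpy as np

# An int repeat count is broadcast to every element.
assert np.repeat(3, 4).tolist() == [3, 3, 3, 3]

x = np.array([[1, 2], [3, 4]])
# Without an axis, the input is flattened first.
assert np.repeat(x, 2).tolist() == [1, 1, 2, 2, 3, 3, 4, 4]
# A per-element repeat count may be zero, dropping that element.
assert np.repeat(x, [0, 1, 2, 0]).tolist() == [2, 3, 3]
# With axis=0, whole rows are repeated.
assert np.repeat(x, [1, 2], axis=0).tolist() == [[1, 2], [3, 4], [3, 4]]
```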


def reshape(a, shape, new_split=None):
"""
Returns a tensor with the same data and number of elements as a, but with the specified shape.
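The distributed `axis=None` path in `repeat` above boils down to: align `a` and `repeats` on split axis 0, repeat each process-local chunk, then rebalance. A single-process sketch of that idea, with NumPy chunks standing in for MPI ranks (`chunked_repeat` is a hypothetical helper for illustration, not part of Heat):

```python
import numpy as np

def chunked_repeat(a, repeats, nprocs=2):
    """Simulate the distributed axis=None case on one process:
    split `a` and `repeats` into matching chunks (one per simulated rank),
    repeat each chunk locally, then concatenate the local results."""
    a = np.asarray(a).ravel()
    repeats = np.asarray(repeats).ravel()
    a_chunks = np.array_split(a, nprocs)
    r_chunks = np.array_split(repeats, nprocs)
    local = [np.repeat(ac, rc) for ac, rc in zip(a_chunks, r_chunks)]
    return np.concatenate(local)

x = [1, 2, 3, 4]
r = [2, 1, 0, 3]
# Matches the single-chunk result.
assert chunked_repeat(x, r).tolist() == np.repeat(x, r).tolist()
```

As long as `a` and `repeats` are chunked identically, the local repeats concatenate to the global result; this is why the implementation resplits both arrays to axis 0 before calling `torch.repeat_interleave`.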
1 change: 1 addition & 0 deletions heat/core/tests/test_factories.py
@@ -1,3 +1,4 @@
import numpy as np
import torch

import heat as ht