Skip to content

Commit

Permalink
Added documentation for the generic reduce_op. Fixes issues #102.
Browse files Browse the repository at this point in the history
  • Loading branch information
Markus-Goetz committed Apr 2, 2020
1 parent b533215 commit faaf3ae
Showing 1 changed file with 38 additions and 6 deletions.
44 changes: 38 additions & 6 deletions heat/core/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from . import types

__all__ = []
__BOOLEAN_OPS = [MPI.LAND, MPI.LOR, MPI.BAND, MPI.BOR]


def __binary_op(operation, t1, t2):
Expand All @@ -32,7 +33,7 @@ def __binary_op(operation, t1, t2):
Returns
-------
result: ht.DNDarray
A tensor containing the results of element-wise operation.
A DNDarray containing the results of element-wise operation.
"""
if np.isscalar(t1):
try:
Expand Down Expand Up @@ -237,8 +238,39 @@ def __local_op(operation, x, out, no_cast=False, **kwargs):
return out


def __reduce_op(x, partial_op, reduction_op, **kwargs):
# TODO: document me Issue #102
def __reduce_op(x, partial_op, reduction_op, neutral=None, **kwargs):
"""
Generic wrapper for reduction operations, e.g. sum(), prod() etc. Performs a two-stage reduction. First, a partial
reduction is performed node-local that is combined into a global reduction result via an MPI_Op.
Parameters
----------
x : ht.DNDarray
The heat DNDarray on which to perform the reduction operation
partial_op: function
The function performing a partial reduction on the process-local data portion, e.g. sum() for implementing a
distributed mean() operation.
reduction_op: mpi4py.MPI.Op
The MPI operator for performing the full reduction based on the results returned by the partial_op function.
neutral: scalar
Neutral element for the reduction operation, i.e. an element that does not change the reductions operations
result. Required in cases where
Returns
-------
result: ht.DNDarray
A DNDarray containing the result of the reduction operation
Raises
------
TypeError
If the input or optional output parameter are not of type ht.DNDarray
ValueError
If the shape of the optional output parameters does not match the shape of the reduced result
"""
# perform sanitation
if not isinstance(x, dndarray.DNDarray):
raise TypeError("expected x to be a ht.DNDarray, but was {}".format(type(x)))
Expand All @@ -255,13 +287,14 @@ def __reduce_op(x, partial_op, reduction_op, **kwargs):

# if local tensor is empty, replace it with the identity element
if 0 in x.lshape and (axis is None or (x.split in axis)):
neutral = kwargs.get("neutral")
if neutral is None:
neutral = float("nan")
neutral_shape = x.lshape[:split] + (1,) + x.lshape[split + 1 :]
partial = torch.full(neutral_shape, fill_value=neutral, dtype=x._DNDarray__array.dtype)
else:
partial = x._DNDarray__array

#
if axis is None:
partial = partial_op(partial).reshape(-1)
output_shape = (1,)
Expand Down Expand Up @@ -294,8 +327,7 @@ def __reduce_op(x, partial_op, reduction_op, **kwargs):
x.comm.Allreduce(MPI.IN_PLACE, partial, reduction_op)

# if reduction_op is a Boolean operation, then resulting tensor is bool
boolean_ops = [MPI.LAND, MPI.LOR, MPI.BAND, MPI.BOR]
tensor_type = bool if reduction_op in boolean_ops else partial.dtype
tensor_type = bool if reduction_op in __BOOLEAN_OPS else partial.dtype

if out is not None:
out._DNDarray__array = partial
Expand Down

0 comments on commit faaf3ae

Please sign in to comment.