Commit 8ecc2a7
Merge pull request #74 from helmholtz-analytics/features/reduceops
Addressing issue #72
ClaudiaComito authored Dec 8, 2018
2 parents 0d6596c + 150d882 commit 8ecc2a7
Showing 5 changed files with 561 additions and 183 deletions.
256 changes: 194 additions & 62 deletions heat/core/operations.py
@@ -9,6 +9,7 @@
__all__ = [
'abs',
'absolute',
'argmin',
'clip',
'copy',
'exp',
@@ -18,6 +19,7 @@
'min',
'sin',
'sqrt',
'sum',
'tril',
'triu'
]
@@ -48,7 +50,8 @@ def abs(x, out=None, dtype=None):

absolute_values = __local_operation(torch.abs, x, out)
if dtype is not None:
absolute_values._tensor__array = absolute_values._tensor__array.type(dtype.torch_type())
absolute_values._tensor__array = absolute_values._tensor__array.type(
dtype.torch_type())
absolute_values._tensor__dtype = dtype

return absolute_values
@@ -79,15 +82,58 @@ def absolute(x, out=None, dtype=None):
return abs(x, out, dtype)


def argmin(x, axis):
# TODO: document me
# TODO: test me
# TODO: sanitize input
# TODO: make me more numpy API complete
# TODO: Fix me, I am not reduce_op.MIN!
#
_, argmin_axis = x._tensor__array.min(dim=axis, keepdim=True)
return __reduce_op(x, argmin_axis, MPI.MIN, axis)
def argmin(x, axis=None):
'''
Returns the indices of the minimum values along an axis.
Parameters:
----------
x : ht.tensor
Input array.
axis : int, optional
By default, the index is into the flattened tensor, otherwise along the specified axis.
# TODO out : array, optional
If provided, the result will be inserted into this tensor. It should be of the appropriate shape and dtype.
Returns:
-------
index_tensor : ht.tensor of ints
Array of indices into the array. It has the same shape as x.shape with the dimension along axis removed.
Examples:
--------
>>> a = ht.randn(3,3)
>>> a
tensor([[-1.7297, 0.2541, -0.1044],
[ 1.0865, -0.4415, 1.3716],
[-0.0827, 1.0215, -2.0176]])
>>> ht.argmin(a)
tensor([8])
>>> ht.argmin(a, axis=0)
tensor([[0, 1, 2]])
>>> ht.argmin(a, axis=1)
tensor([[0],
[1],
[2]])
'''

if axis is None:
# TEMPORARY SOLUTION! TODO: implementation for axis=None, distributed tensor
# perform sanitation
if not isinstance(x, tensor.tensor):
raise TypeError(
'expected x to be a ht.tensor, but was {}'.format(type(x)))
axis = stride_tricks.sanitize_axis(x.shape, axis)
out = torch.reshape(torch.argmin(x._tensor__array), (1,))
return tensor.tensor(out, out.shape, types.canonical_heat_type(out.dtype), split=None, comm=x.comm)

out = __reduce_op(x, torch.min, MPI.MIN, axis)._tensor__array[1]
return tensor.tensor(out, out.shape, types.canonical_heat_type(out.dtype), x._tensor__split, comm=x.comm)
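
Note for reviewers: the axis branch above relies on torch.min (and torch.max) returning a (values, indices) pair when called with a dim argument; __reduce_op carries that pair along and argmin then picks out component [1]. A minimal, standalone sketch of just that PyTorch behaviour (no ht.tensor or MPI involved):

import torch

t = torch.tensor([[-1.7297,  0.2541, -0.1044],
                  [ 1.0865, -0.4415,  1.3716],
                  [-0.0827,  1.0215, -2.0176]])

# torch.min with a dim argument returns both the minima and their indices
values, indices = torch.min(t, dim=0, keepdim=True)
print(values)   # tensor([[-1.7297, -0.4415, -2.0176]])
print(indices)  # tensor([[0, 1, 2]])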


def clip(a, a_min, a_max, out=None):
@@ -229,59 +275,92 @@ def log(x, out=None):

def max(x, axis=None):
""""
Return the maximum of an array or maximum along an axis.
Return a tuple containing:
- the maximum of an array or maximum along an axis;
- indices of maxima
Parameters
----------
x : ht.tensor
Input data.
axis : None or int, optional
Axis or axes along which to operate. By default, flattened input is used.
#TODO: out : ht.tensor, optional
Alternative output array in which to place the result. Must be of the same shape and buffer length as the expected output.
Axis or axes along which to operate. By default, flattened input is used.
# TODO: out : ht.tensor, optional
Tuple of two output tensors (max, max_indices). Must be of the same shape and buffer length as the expected output.
#TODO: initial : scalar, optional
# TODO: initial : scalar, optional
The minimum value of an output element. Must be present to allow computation on empty slice.
"""
#perform sanitation:
axis = stride_tricks.sanitize_axis(x.shape,axis)

if axis is not None:
max_axis, _ = x._tensor__array.max(axis, keepdim=True)
else:
return x._tensor__array.max()
return __reduce_op(x, max_axis, MPI.MAX, axis)
Examples
--------
>>> a = ht.float32([
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12]
])
>>> ht.max(a)
tensor([12.])
>>> ht.max(a, axis=0)
(tensor([[10., 11., 12.]]), tensor([[3, 3, 3]]))
>>> ht.max(a, axis=1)
(tensor([[ 3.],
[ 6.],
[ 9.],
[12.]]), tensor([[2],
[2],
[2],
[2]]))
"""
return __reduce_op(x, torch.max, MPI.MAX, axis)


def min(x, axis=None):
""""
Return the minimum of an array or minimum along an axis.
Return a tuple containing:
- the minimum of an array or minimum along an axis;
- indices of minima
Parameters
----------
x : ht.tensor
Input data.
axis : None or int
Axis or axes along which to operate. By default, flattened input is used.
#TODO: out : ht.tensor, optional
Alternative output array in which to place the result. Must be of the same shape and buffer length as the expected output.
Axis or axes along which to operate. By default, flattened input is used.
#TODO: initial : scalar, optional
# TODO: out : ht.tensor, optional
Tuple of two output tensors (min, min_indices). Must be of the same shape and buffer length as the expected output.
# TODO: initial : scalar, optional
The maximum value of an output element. Must be present to allow computation on empty slice.
"""
#perform sanitation:
axis = stride_tricks.sanitize_axis(x.shape,axis)
if axis is not None:
min_axis, _ = x._tensor__array.min(axis, keepdim=True)
else:
return x._tensor__array.min()
return __reduce_op(x, min_axis, MPI.MIN, axis)
Examples
--------
>>> a = ht.float32([
[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11, 12]
])
>>> ht.min(a)
tensor([1.])
>>> ht.min(a, axis=0)
(tensor([[1., 2., 3.]]), tensor([[0, 0, 0]]))
>>> ht.min(a, axis=1)
(tensor([[ 1.],
[ 4.],
[ 7.],
[10.]]), tensor([[0],
[0],
[0],
[0]]))
"""
return __reduce_op(x, torch.min, MPI.MIN, axis)


def sin(x, out=None):
@@ -310,17 +389,17 @@ def sin(x, out=None):
return __local_operation(torch.sin, x, out)


def sum(x, axis=None):
# TODO: document me
axis = stride_tricks.sanitize_axis(x.shape, axis)
if axis is not None:
sum_axis = x._tensor__array.sum(axis, keepdim=True)
else:
sum_axis = torch.reshape(x._tensor__array.sum(), (1,))
if not x.comm.is_distributed():
return tensor.tensor(sum_axis, (1,), types.canonical_heat_type(sum_axis.dtype), None, x.comm)
# def sum(x, axis=None):
# # TODO: document me
# axis = stride_tricks.sanitize_axis(x.shape, axis)
# if axis is not None:
# sum_axis = x._tensor__array.sum(axis, keepdim=True)
# else:
# sum_axis = torch.reshape(x._tensor__array.sum(), (1,))
# if not x.comm.is_distributed():
# return tensor.tensor(sum_axis, (1,), types.canonical_heat_type(sum_axis.dtype), None, x.comm)

return __reduce_op(x, sum_axis, MPI.SUM, axis)
# return __reduce_op(x, sum_axis, MPI.SUM, axis)


def sqrt(x, out=None):
@@ -351,6 +430,47 @@ def sqrt(x, out=None):
return __local_operation(torch.sqrt, x, out)


def sum(x, axis=None):
"""
Sum of array elements over a given axis.
Parameters
----------
x : ht.tensor
Input data.
axis : None or int, optional
Axis along which a sum is performed. The default, axis=None, will sum
all of the elements of the input array. If axis is negative it counts
from the last to the first axis.
Returns
-------
sum_along_axis : ht.tensor
An array with the same shape as x, except for the specified axis, which
becomes one, e.g. a.shape = (1,2,3) => ht.ones((1,2,3)).sum(axis=1).shape = (1,1,3)
Examples
--------
>>> ht.sum(ht.ones(2))
tensor([2.])
>>> ht.sum(ht.ones((3,3)))
tensor([9.])
>>> ht.sum(ht.ones((3,3)).astype(ht.int))
tensor([9])
>>> ht.sum(ht.ones((3,2,1)), axis=-3)
tensor([[[3.],
[3.]]])
"""

# TODO: make me more numpy API complete

return __reduce_op(x, torch.sum, MPI.SUM, axis)
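
Note: the shape contract stated in the docstring (the reduced axis becomes 1 rather than disappearing) comes from the keepdim=True reduction inside __reduce_op. A quick plain-PyTorch illustration of the difference:

import torch

a = torch.ones((1, 2, 3))
print(a.sum(1, keepdim=True).shape)  # torch.Size([1, 1, 3]) -- matches ht.sum's shape contract
print(a.sum(1).shape)                # torch.Size([1, 3])    -- NumPy-style sum without keepdims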


def __local_operation(operation, x, out):
"""
Generic wrapper for local operations, which do not require communication. Accepts the actual operation function as
@@ -379,9 +499,11 @@ def __local_operation(operation, x, out):
"""
# perform sanitation
if not isinstance(x, tensor.tensor):
raise TypeError('expected x to be a ht.tensor, but was {}'.format(type(x)))
raise TypeError(
'expected x to be a ht.tensor, but was {}'.format(type(x)))
if out is not None and not isinstance(out, tensor.tensor):
raise TypeError('expected out to be None or an ht.tensor, but was {}'.format(type(out)))
raise TypeError(
'expected out to be None or an ht.tensor, but was {}'.format(type(out)))

# infer the output type of the tensor
# we need floating point numbers here, due to PyTorch only providing sqrt() implementation for float32/64
@@ -403,7 +525,8 @@

# do an inplace operation into a provided buffer
casted = x._tensor__array.type(torch_type)
operation(casted.repeat(multiples) if needs_repetition else casted, out=out._tensor__array)
operation(casted.repeat(multiples)
if needs_repetition else casted, out=out._tensor__array)
return out


@@ -441,7 +564,8 @@ def __tri_op(m, k, op):
try:
k = int(k)
except ValueError:
raise TypeError('Expected k to be integral, but was {}'.format(type(k)))
raise TypeError(
'Expected k to be integral, but was {}'.format(type(k)))

# chunk the global shape of the tensor to obtain the offset compared to the other ranks
offset, _, _ = m.comm.chunk(m.shape, m.split)
@@ -525,22 +649,30 @@ def triu(m, k=0):
return __tri_op(m, k, torch.triu)


def __reduce_op(x, partial, op, axis):
def __reduce_op(x, partial_op, op, axis):
# TODO: document me
# TODO: test me
# TODO: make me more numpy API complete
# TODO: e.g. allow axis to be a tuple, allow for "initial"
# TODO: implement type promotion

# perform sanitation
if not isinstance(x, tensor.tensor):
raise TypeError('expected x to be a ht.tensor, but was {}'.format(type(x)))
raise TypeError(
'expected x to be a ht.tensor, but was {}'.format(type(x)))
# no further checking needed, sanitize axis will raise the proper exceptions
axis = stride_tricks.sanitize_axis(x.shape, axis)

if axis is None:
partial = torch.reshape(partial_op(x._tensor__array), (1,))
output_shape = partial.shape
else:
partial = partial_op(x._tensor__array, axis, keepdim=True)
# TODO: verify if this works for negative split axis
output_shape = x.gshape[:axis] + (1,) + x.gshape[axis + 1:]

if x.comm.is_distributed() and (axis is None or axis == x.split):
x.comm.Allreduce(MPI.IN_PLACE, partial, op)
return tensor.tensor(partial, partial.shape, types.canonical_heat_type(partial.dtype), split=None, comm=x.comm)
x.comm.Allreduce(MPI.IN_PLACE, partial[0], op)
return tensor.tensor(partial, output_shape, types.canonical_heat_type(partial[0].dtype), split=None, comm=x.comm)

# TODO: verify if this works for negative split axis
output_shape = x.shape[:axis] + (1,) + x.shape[axis + 1:]
return tensor.tensor(partial, output_shape, types.canonical_heat_type(partial.dtype), split=None, comm=x.comm)
return tensor.tensor(partial, output_shape, types.canonical_heat_type(partial[0].dtype), split=None, comm=x.comm)
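
Since __reduce_op is still marked "TODO: document me", a short summary of its flow for reviewers: the process-local chunk is reduced with the given torch operation (keepdim=True for an axis reduction, reshaped to a single element for axis=None), and an MPI Allreduce is only issued when the reduction crosses process boundaries, i.e. for a full reduction or a reduction along the split dimension. The sketch below mirrors that branching with plain torch and mpi4py, outside the ht.tensor machinery; sketch_reduce and its parameters are illustrative names, not part of the HeAT API, and it only covers single-tensor results such as torch.sum (torch.min/torch.max additionally return indices, which __reduce_op carries along).

import torch
from mpi4py import MPI

def sketch_reduce(local_chunk, torch_op, mpi_op, axis, split, comm):
    # Illustrative only: each rank holds a chunk of a tensor that is
    # (conceptually) concatenated along `split`.
    if axis is None:
        # full reduction: collapse the local chunk to a single element
        partial = torch.reshape(torch_op(local_chunk), (1,))
    else:
        # axis reduction, keeping the reduced dimension
        partial = torch_op(local_chunk, axis, keepdim=True)

    # communication is needed only when the reduction spans several ranks:
    # a full reduction, or a reduction along the split dimension
    if axis is None or axis == split:
        comm.Allreduce(MPI.IN_PLACE, partial.numpy(), mpi_op)
    return partial

chunk = torch.arange(6, dtype=torch.float32).reshape(2, 3)
print(sketch_reduce(chunk, torch.sum, MPI.SUM, 0, 0, MPI.COMM_WORLD))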