diff --git a/heat/core/statistics.py b/heat/core/statistics.py index 7608da7387..d6c0305a34 100644 --- a/heat/core/statistics.py +++ b/heat/core/statistics.py @@ -206,14 +206,12 @@ def average( Axis or axes along which to average ``x``. The default, ``axis=None``, will average over all of the elements of the input array. If axis is negative it counts from the last to the first axis. - #TODO Issue #351: If axis is a tuple of ints, averaging is performed on all of the axes - specified in the tuple instead of a single axis or all the axes as - before. weights : DNDarray, optional An array of weights associated with the values in ``x``. Each value in ``x`` contributes to the average according to its associated weight. The weights array can either be 1D (in which case its length must be the size of ``x`` along the given axis) or of the same shape as ``x``. + Weighted average over tuple axis requires weights array to be of the same shape as ``x``. If ``weights=None``, then all data in ``x`` are assumed to have a weight equal to one, the result is equivalent to :func:`mean`. returned : bool, optional @@ -269,7 +267,9 @@ def average( if axis is None: raise TypeError("Axis must be specified when shapes of x and weights differ.") elif isinstance(axis, tuple): - raise NotImplementedError("Weighted average over tuple axis not implemented yet.") + raise ValueError( + "Weighted average over tuple axis requires weights to be of the same shape as x." + ) if weights.ndim != 1: raise TypeError("1D weights expected when shapes of x and weights differ.") if weights.gshape[0] != x.gshape[axis]: diff --git a/heat/core/tests/test_statistics.py b/heat/core/tests/test_statistics.py index 922fd74aa8..2ce40a8201 100644 --- a/heat/core/tests/test_statistics.py +++ b/heat/core/tests/test_statistics.py @@ -306,7 +306,7 @@ def test_average(self): ht.average(random_5d, weights=random_weights.numpy(), axis=axis) with self.assertRaises(TypeError): ht.average(random_5d, weights=random_weights, axis=None) - with self.assertRaises(NotImplementedError): + with self.assertRaises(ValueError): ht.average(random_5d, weights=random_weights, axis=(1, 2)) random_weights = ht.random.randn(random_5d.gshape[axis], random_5d.gshape[axis + 1]) with self.assertRaises(TypeError):