Skip to content

Commit

Permalink
Refactor tests for ev.aggregations (awslabs#3038)
Browse files Browse the repository at this point in the history
  • Loading branch information
lostella committed Nov 13, 2023
1 parent ca40b46 commit b31b202
Showing 1 changed file with 60 additions and 53 deletions.
113 changes: 60 additions & 53 deletions test/ev/test_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,61 +18,39 @@
from gluonts.ev import Mean, Sum
from gluonts.itertools import power_set

VALUE_STREAM = [
[
np.full((3, 5), np.nan),
np.full((3, 5), np.nan),
np.full((3, 5), np.nan),
],
[
np.array([[0, np.nan], [0, 0]]),
np.array([[0, 5], [-5, np.nan]]),
],
[
np.full(shape=(3, 3), fill_value=1),
np.full(shape=(1, 3), fill_value=4),
],
]

SUM_RES_AXIS_NONE = [
0,
0,
21,
]

SUM_RES_AXIS_0 = [
np.zeros(5),
np.array([-5, 5]),
np.array([7, 7, 7]),
]
SUM_RES_AXIS_1 = [
np.zeros(9),
np.array([0, 0, 5, -5]),
np.array([3, 3, 3, 12]),
]


MEAN_RES_AXIS_NONE = [
np.nan,
0,
1.75,
]

MEAN_RES_AXIS_0 = [
np.full(5, np.nan),
np.array([-1.25, 2.5]),
np.array([1.75, 1.75, 1.75]),
]
MEAN_RES_AXIS_1 = [
np.full(9, np.nan),
np.array([0, 0, 2.5, -5]),
np.array([1, 1, 1, 4]),
]


@pytest.mark.parametrize(
"value_stream, res_axis_none, res_axis_0, res_axis_1",
zip(VALUE_STREAM, SUM_RES_AXIS_NONE, SUM_RES_AXIS_0, SUM_RES_AXIS_1),
[
(
[
np.full((3, 5), np.nan),
np.full((3, 5), np.nan),
np.full((3, 5), np.nan),
],
0,
np.zeros(5),
np.zeros(9),
),
(
[
np.array([[0, np.nan], [0, 0]]),
np.array([[0, 5], [-5, np.nan]]),
],
0,
np.array([-5, 5]),
np.array([0, 0, 5, -5]),
),
(
[
np.full(shape=(3, 3), fill_value=1),
np.full(shape=(1, 3), fill_value=4),
],
21,
np.array([7, 7, 7]),
np.array([3, 3, 3, 12]),
),
],
)
def test_Sum(value_stream, res_axis_none, res_axis_0, res_axis_1):
for axis, expected_result in zip(
Expand All @@ -87,7 +65,36 @@ def test_Sum(value_stream, res_axis_none, res_axis_0, res_axis_1):

@pytest.mark.parametrize(
"value_stream, res_axis_none, res_axis_0, res_axis_1",
zip(VALUE_STREAM, MEAN_RES_AXIS_NONE, MEAN_RES_AXIS_0, MEAN_RES_AXIS_1),
[
(
[
np.full((3, 5), np.nan),
np.full((3, 5), np.nan),
np.full((3, 5), np.nan),
],
np.nan,
np.full(5, np.nan),
np.full(9, np.nan),
),
(
[
np.array([[0, np.nan], [0, 0]]),
np.array([[0, 5], [-5, np.nan]]),
],
0,
np.array([-1.25, 2.5]),
np.array([0, 0, 2.5, -5]),
),
(
[
np.full(shape=(3, 3), fill_value=1),
np.full(shape=(1, 3), fill_value=4),
],
1.75,
np.array([1.75, 1.75, 1.75]),
np.array([1, 1, 1, 4]),
),
],
)
def test_Mean(value_stream, res_axis_none, res_axis_0, res_axis_1):
for axis, expected_result in zip(
Expand Down

0 comments on commit b31b202

Please sign in to comment.