From 6ef58845b0b4c2a71b33027770940ce5786106b2 Mon Sep 17 00:00:00 2001 From: James McKinney <26463+jpmckinney@users.noreply.github.com> Date: Tue, 3 Oct 2023 23:29:39 -0400 Subject: [PATCH] feat: Mean() works with TimeDelta(), #761 --- CHANGELOG.rst | 1 + agate/aggregations/mean.py | 11 +++++++---- tests/test_aggregations.py | 29 ++++++++++++++++++++++++++--- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 169c3290..53574945 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,7 @@ ----- * feat: Lowercase the ``null_values`` provided to individual data types, since all comparisons to ``null_values`` are case-insensitive. (#770) +* feat: :class:`.Mean` works with :class:`.TimeDelta`. (#761) * fix: Allow consecutive calls to :meth:`.Table.group_by`. (#765) 1.7.1 - Jan 4, 2023 diff --git a/agate/aggregations/mean.py b/agate/aggregations/mean.py index 107466f3..9fe27dc9 100644 --- a/agate/aggregations/mean.py +++ b/agate/aggregations/mean.py @@ -1,7 +1,7 @@ from agate.aggregations.base import Aggregation from agate.aggregations.has_nulls import HasNulls from agate.aggregations.sum import Sum -from agate.data_types import Number +from agate.data_types import Number, TimeDelta from agate.exceptions import DataTypeError from agate.warns import warn_null_calculation @@ -18,13 +18,16 @@ def __init__(self, column_name): self._sum = Sum(column_name) def get_aggregate_data_type(self, table): - return Number() + column = table.columns[self._column_name] + + if isinstance(column.data_type, (Number, TimeDelta)): + return column.data_type def validate(self, table): column = table.columns[self._column_name] - if not isinstance(column.data_type, Number): - raise DataTypeError('Mean can only be applied to columns containing Number data.') + if not isinstance(column.data_type, (Number, TimeDelta)): + raise DataTypeError('Mean can only be applied to columns containing Number or TimeDelta data.') has_nulls = HasNulls(self._column_name).run(table) diff --git a/tests/test_aggregations.py b/tests/test_aggregations.py index c7fd352a..c8dba876 100644 --- a/tests/test_aggregations.py +++ b/tests/test_aggregations.py @@ -184,11 +184,13 @@ def setUp(self): self.table = Table(self.rows, ['test', 'null'], [DateTime(), DateTime()]) self.time_delta_rows = [ - [datetime.timedelta(seconds=10), None], - [datetime.timedelta(seconds=20), None], + [datetime.timedelta(seconds=10), datetime.timedelta(seconds=15), None], + [datetime.timedelta(seconds=20), None, None], ] - self.time_delta_table = Table(self.time_delta_rows, ['test', 'null'], [TimeDelta(), TimeDelta()]) + self.time_delta_table = Table( + self.time_delta_rows, ['test', 'mixed', 'null'], [TimeDelta(), TimeDelta(), TimeDelta()] + ) def test_min(self): self.assertIsInstance(Min('test').get_aggregate_data_type(self.table), DateTime) @@ -216,6 +218,27 @@ def test_max_time_delta(self): Max('test').validate(self.time_delta_table) self.assertEqual(Max('test').run(self.time_delta_table), datetime.timedelta(0, 20)) + def test_mean(self): + with self.assertWarns(NullCalculationWarning): + Mean('mixed').validate(self.time_delta_table) + + Mean('test').validate(self.time_delta_table) + + self.assertEqual(Mean('test').run(self.time_delta_table), datetime.timedelta(seconds=15)) + + def test_mean_all_nulls(self): + self.assertIsNone(Mean('null').run(self.time_delta_table)) + + def test_mean_with_nulls(self): + warnings.simplefilter('ignore') + + try: + Mean('mixed').validate(self.time_delta_table) + finally: + warnings.resetwarnings() + + self.assertAlmostEqual(Mean('mixed').run(self.time_delta_table), datetime.timedelta(seconds=15)) + def test_sum(self): self.assertIsInstance(Sum('test').get_aggregate_data_type(self.time_delta_table), TimeDelta) Sum('test').validate(self.time_delta_table)