Skip to content

Commit

Permalink
feat: Mean() works with TimeDelta(), #761
Browse files Browse the repository at this point in the history
  • Loading branch information
jpmckinney committed Oct 4, 2023
1 parent 469ee27 commit 6ef5884
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
-----

* feat: Lowercase the ``null_values`` provided to individual data types, since all comparisons to ``null_values`` are case-insensitive. (#770)
* feat: :class:`.Mean` works with :class:`.TimeDelta`. (#761)
* fix: Allow consecutive calls to :meth:`.Table.group_by`. (#765)

1.7.1 - Jan 4, 2023
Expand Down
11 changes: 7 additions & 4 deletions agate/aggregations/mean.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from agate.aggregations.base import Aggregation
from agate.aggregations.has_nulls import HasNulls
from agate.aggregations.sum import Sum
from agate.data_types import Number
from agate.data_types import Number, TimeDelta
from agate.exceptions import DataTypeError
from agate.warns import warn_null_calculation

Expand All @@ -18,13 +18,16 @@ def __init__(self, column_name):
self._sum = Sum(column_name)

def get_aggregate_data_type(self, table):
return Number()
column = table.columns[self._column_name]

if isinstance(column.data_type, (Number, TimeDelta)):
return column.data_type

def validate(self, table):
column = table.columns[self._column_name]

if not isinstance(column.data_type, Number):
raise DataTypeError('Mean can only be applied to columns containing Number data.')
if not isinstance(column.data_type, (Number, TimeDelta)):
raise DataTypeError('Mean can only be applied to columns containing Number or TimeDelta data.')

has_nulls = HasNulls(self._column_name).run(table)

Expand Down
29 changes: 26 additions & 3 deletions tests/test_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,11 +184,13 @@ def setUp(self):
self.table = Table(self.rows, ['test', 'null'], [DateTime(), DateTime()])

self.time_delta_rows = [
[datetime.timedelta(seconds=10), None],
[datetime.timedelta(seconds=20), None],
[datetime.timedelta(seconds=10), datetime.timedelta(seconds=15), None],
[datetime.timedelta(seconds=20), None, None],
]

self.time_delta_table = Table(self.time_delta_rows, ['test', 'null'], [TimeDelta(), TimeDelta()])
self.time_delta_table = Table(
self.time_delta_rows, ['test', 'mixed', 'null'], [TimeDelta(), TimeDelta(), TimeDelta()]
)

def test_min(self):
self.assertIsInstance(Min('test').get_aggregate_data_type(self.table), DateTime)
Expand Down Expand Up @@ -216,6 +218,27 @@ def test_max_time_delta(self):
Max('test').validate(self.time_delta_table)
self.assertEqual(Max('test').run(self.time_delta_table), datetime.timedelta(0, 20))

def test_mean(self):
with self.assertWarns(NullCalculationWarning):
Mean('mixed').validate(self.time_delta_table)

Mean('test').validate(self.time_delta_table)

self.assertEqual(Mean('test').run(self.time_delta_table), datetime.timedelta(seconds=15))

def test_mean_all_nulls(self):
self.assertIsNone(Mean('null').run(self.time_delta_table))

def test_mean_with_nulls(self):
warnings.simplefilter('ignore')

try:
Mean('mixed').validate(self.time_delta_table)
finally:
warnings.resetwarnings()

self.assertAlmostEqual(Mean('mixed').run(self.time_delta_table), datetime.timedelta(seconds=15))

def test_sum(self):
self.assertIsInstance(Sum('test').get_aggregate_data_type(self.time_delta_table), TimeDelta)
Sum('test').validate(self.time_delta_table)
Expand Down

0 comments on commit 6ef5884

Please sign in to comment.