Skip to content

Commit

Permalink
Moved tests from groupby to transform modules
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd committed Mar 8, 2018
1 parent 4bc572b commit db8ea80
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 55 deletions.
55 changes: 0 additions & 55 deletions pandas/tests/groupby/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2062,61 +2062,6 @@ def test_rank_object_raises(self, ties_method, ascending, na_option,
ascending=ascending,
na_option=na_option, pct=pct)

@pytest.mark.parametrize("mix_groupings", [True, False])
@pytest.mark.parametrize("as_series", [True, False])
@pytest.mark.parametrize("val1,val2", [
('foo', 'bar'), (1, 2), (1., 2.)])
@pytest.mark.parametrize("fill_method,limit,exp_vals", [
("ffill", None,
[np.nan, np.nan, 'val1', 'val1', 'val1', 'val2', 'val2', 'val2']),
("ffill", 1,
[np.nan, np.nan, 'val1', 'val1', np.nan, 'val2', 'val2', np.nan]),
("bfill", None,
['val1', 'val1', 'val1', 'val2', 'val2', 'val2', np.nan, np.nan]),
("bfill", 1,
[np.nan, 'val1', 'val1', np.nan, 'val2', 'val2', np.nan, np.nan])
])
def test_group_fill_methods(self, mix_groupings, as_series, val1, val2,
fill_method, limit, exp_vals):
vals = [np.nan, np.nan, val1, np.nan, np.nan, val2, np.nan, np.nan]
_exp_vals = list(exp_vals)
# Overwrite placeholder values
for index, exp_val in enumerate(_exp_vals):
if exp_val == 'val1':
_exp_vals[index] = val1
elif exp_val == 'val2':
_exp_vals[index] = val2

# Need to modify values and expectations depending on the
# Series / DataFrame that we ultimately want to generate
if mix_groupings: # ['a', 'b', 'a, 'b', ...]
keys = ['a', 'b'] * len(vals)

def interweave(list_obj):
temp = list()
for x in list_obj:
temp.extend([x, x])

return temp

_exp_vals = interweave(_exp_vals)
vals = interweave(vals)
else: # ['a', 'a', 'a', ... 'b', 'b', 'b']
keys = ['a'] * len(vals) + ['b'] * len(vals)
_exp_vals = _exp_vals * 2
vals = vals * 2

df = DataFrame({'key': keys, 'val': vals})
if as_series:
result = getattr(
df.groupby('key')['val'], fill_method)(limit=limit)
exp = Series(_exp_vals, name='val')
assert_series_equal(result, exp)
else:
result = getattr(df.groupby('key'), fill_method)(limit=limit)
exp = DataFrame({'key': keys, 'val': _exp_vals})
assert_frame_equal(result, exp)

@pytest.mark.parametrize("agg_func", ['any', 'all'])
@pytest.mark.parametrize("skipna", [True, False])
@pytest.mark.parametrize("vals", [
Expand Down
87 changes: 87 additions & 0 deletions pandas/tests/groupby/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,3 +636,90 @@ def test_transform_numeric_ret(self, cols, exp, comp_func, agg_func):
exp = exp.astype('float')

comp_func(result, exp)

@pytest.mark.parametrize("mix_groupings", [True, False])
@pytest.mark.parametrize("as_series", [True, False])
@pytest.mark.parametrize("val1,val2", [
('foo', 'bar'), (1, 2), (1., 2.)])
@pytest.mark.parametrize("fill_method,limit,exp_vals", [
("ffill", None,
[np.nan, np.nan, 'val1', 'val1', 'val1', 'val2', 'val2', 'val2']),
("ffill", 1,
[np.nan, np.nan, 'val1', 'val1', np.nan, 'val2', 'val2', np.nan]),
("bfill", None,
['val1', 'val1', 'val1', 'val2', 'val2', 'val2', np.nan, np.nan]),
("bfill", 1,
[np.nan, 'val1', 'val1', np.nan, 'val2', 'val2', np.nan, np.nan])
])
def test_group_fill_methods(self, mix_groupings, as_series, val1, val2,
fill_method, limit, exp_vals):
vals = [np.nan, np.nan, val1, np.nan, np.nan, val2, np.nan, np.nan]
_exp_vals = list(exp_vals)
# Overwrite placeholder values
for index, exp_val in enumerate(_exp_vals):
if exp_val == 'val1':
_exp_vals[index] = val1
elif exp_val == 'val2':
_exp_vals[index] = val2

# Need to modify values and expectations depending on the
# Series / DataFrame that we ultimately want to generate
if mix_groupings: # ['a', 'b', 'a, 'b', ...]
keys = ['a', 'b'] * len(vals)

def interweave(list_obj):
temp = list()
for x in list_obj:
temp.extend([x, x])

return temp

_exp_vals = interweave(_exp_vals)
vals = interweave(vals)
else: # ['a', 'a', 'a', ... 'b', 'b', 'b']
keys = ['a'] * len(vals) + ['b'] * len(vals)
_exp_vals = _exp_vals * 2
vals = vals * 2

df = DataFrame({'key': keys, 'val': vals})
if as_series:
result = getattr(
df.groupby('key')['val'], fill_method)(limit=limit)
exp = Series(_exp_vals, name='val')
assert_series_equal(result, exp)
else:
result = getattr(df.groupby('key'), fill_method)(limit=limit)
exp = DataFrame({'key': keys, 'val': _exp_vals})
assert_frame_equal(result, exp)

@pytest.mark.parametrize("test_series", [True, False])
@pytest.mark.parametrize("periods,fill_method,limit", [
(1, 'ffill', None), (1, 'ffill', 1),
(1, 'bfill', None), (1, 'bfill', 1),
(-1, 'ffill', None), (-1, 'ffill', 1),
(-1, 'bfill', None), (-1, 'bfill', 1)])
def test_pct_change(self, test_series, periods, fill_method, limit):
vals = [np.nan, np.nan, 1, 2, 4, 10, np.nan, np.nan]
exp_vals = Series(vals).pct_change(periods=periods,
fill_method=fill_method,
limit=limit).tolist()

df = DataFrame({'key': ['a'] * len(vals) + ['b'] * len(vals),
'vals': vals * 2})
grp = df.groupby('key')

def get_result(grp_obj):
return grp_obj.pct_change(periods=periods,
fill_method=fill_method,
limit=limit)

if test_series:
exp = pd.Series(exp_vals * 2)
exp.name = 'vals'
grp = grp['vals']
result = get_result(grp)
tm.assert_series_equal(result, exp)
else:
exp = DataFrame({'vals': exp_vals * 2})
result = get_result(grp)
tm.assert_frame_equal(result, exp)

0 comments on commit db8ea80

Please sign in to comment.