Skip to content

Commit

Permalink
ENH: add Cython nth/last functions, vbenchmarks. close #1043
Browse files Browse the repository at this point in the history
  • Loading branch information
wesm committed May 12, 2012
1 parent 59f0ee7 commit a98035c
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 18 deletions.
2 changes: 2 additions & 0 deletions RELEASE.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ pandas 0.8.0
sense indexing/selection functionality
- Series/DataFrame.update methods, in-place variant of combine_first (#961)
- Add ``match`` function to API (#502)
- Add Cython-optimized first, last, min, max, prod functions to GroupBy (#994,
#1043)

**Improvements to existing features**

Expand Down
32 changes: 26 additions & 6 deletions pandas/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,20 @@ def f(self):

return f

def _first_compat(x, axis=0):
    """Return the first non-null value of ``x``, or NaN if none exists.

    ``axis`` is accepted only for signature compatibility with the other
    GroupBy aggregation callables; it is not used.
    """
    arr = np.asarray(x)
    valid = arr[com.notnull(arr)]
    return valid[0] if len(valid) > 0 else np.nan

def _last_compat(x, axis=0):
    """Return the last non-null value of ``x``, or NaN if none exists.

    ``axis`` is accepted only for signature compatibility with the other
    GroupBy aggregation callables; it is not used.
    """
    arr = np.asarray(x)
    valid = arr[com.notnull(arr)]
    return valid[-1] if len(valid) > 0 else np.nan


class GroupBy(object):
"""
Expand Down Expand Up @@ -314,6 +328,8 @@ def size(self):
prod = _groupby_function('prod', 'prod', np.prod)
min = _groupby_function('min', 'min', np.min)
max = _groupby_function('max', 'max', np.max)
first = _groupby_function('first', 'first', _first_compat)
last = _groupby_function('last', 'last', _last_compat)

def ohlc(self):
"""
Expand All @@ -323,11 +339,11 @@ def ohlc(self):
"""
return self._cython_agg_general('ohlc')

def last(self):
return self.nth(-1)
# def last(self):
# return self.nth(-1)

def first(self):
return self.nth(0)
# def first(self):
# return self.nth(0)

def nth(self, n):
def picker(arr):
Expand Down Expand Up @@ -621,7 +637,9 @@ def get_group_levels(self):
'max' : lib.group_max,
'mean' : lib.group_mean,
'var' : lib.group_var,
'std' : lib.group_var
'std' : lib.group_var,
'first': lambda a, b, c, d: lib.group_nth(a, b, c, d, 1),
'last': lib.group_last
}

_cython_transforms = {
Expand Down Expand Up @@ -858,7 +876,9 @@ def names(self):
'max' : lib.group_max_bin,
'var' : lib.group_var_bin,
'std' : lib.group_var_bin,
'ohlc' : lib.group_ohlc
'ohlc' : lib.group_ohlc,
'first': lambda a, b, c, d: lib.group_nth_bin(a, b, c, d, 1),
'last': lib.group_last_bin
}

_name_functions = {
Expand Down
183 changes: 183 additions & 0 deletions pandas/src/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,188 @@ def group_prod(ndarray[float64_t, ndim=2] out,
else:
out[i, j] = prodx[i, j]

#----------------------------------------------------------------------
# first, nth, last

@cython.boundscheck(False)
@cython.wraparound(False)
def group_nth(ndarray[float64_t, ndim=2] out,
              ndarray[int64_t] counts,
              ndarray[float64_t, ndim=2] values,
              ndarray[int64_t] labels, int64_t rank):
    '''
    Write the rank-th non-nan value of each (group, column) into ``out``.

    Only aggregates on axis=0.  ``rank`` is 1-based: rank=1 selects the
    first non-nan observation per group (how GroupBy.first is wired up).
    Groups with fewer than ``rank`` non-nan observations yield nan.
    ``counts[lab]`` is incremented once per row of group ``lab`` (nan rows
    included); rows with a negative label (NA group) are skipped entirely.
    '''
    cdef:
        Py_ssize_t i, j, N, K, lab
        float64_t val, count
        ndarray[float64_t, ndim=2] resx
        ndarray[int64_t, ndim=2] nobs

    # nobs[lab, j] counts non-nan observations seen so far for that cell
    nobs = np.zeros((<object> out).shape, dtype=np.int64)
    resx = np.empty_like(out)

    N, K = (<object> values).shape

    for i in range(N):
        lab = labels[i]
        if lab < 0:
            continue

        counts[lab] += 1
        for j in range(K):
            val = values[i, j]

            # not nan (nan != nan)
            if val == val:
                nobs[lab, j] += 1
                if nobs[lab, j] == rank:
                    resx[lab, j] = val

    for i in range(len(counts)):
        for j in range(K):
            # BUG FIX: the original tested ``nobs[i, j] == 0`` here, which
            # copied *uninitialized* resx memory into out whenever a cell
            # had some observations but fewer than ``rank``.  Correct for
            # rank == 1, wrong for any larger rank.
            if nobs[i, j] < rank:
                out[i, j] = nan
            else:
                out[i, j] = resx[i, j]


@cython.boundscheck(False)
@cython.wraparound(False)
def group_nth_bin(ndarray[float64_t, ndim=2] out,
                  ndarray[int64_t] counts,
                  ndarray[float64_t, ndim=2] values,
                  ndarray[int64_t] bins, int64_t rank):
    '''
    Write the rank-th non-nan value of each (bin, column) into ``out``.

    Only aggregates on axis=0.  ``bins`` holds the (sorted) right edges of
    each bin; ``rank`` is 1-based.  Bins with fewer than ``rank`` non-nan
    observations yield nan.  ``counts[b]`` is incremented once per row
    falling into bin ``b`` (nan rows included).
    '''
    cdef:
        Py_ssize_t i, j, N, K, ngroups, b
        float64_t val, count
        ndarray[float64_t, ndim=2] resx
        ndarray[int64_t, ndim=2] nobs

    # Use an integer observation counter, consistent with group_nth; the
    # original declared nobs as float64 and compared counts as floats.
    nobs = np.zeros((<object> out).shape, dtype=np.int64)
    resx = np.empty_like(out)

    # Guard the empty-bins case: with wraparound/boundscheck disabled,
    # bins[len(bins) - 1] would be an out-of-bounds read for len(bins) == 0.
    if len(bins) > 0 and bins[len(bins) - 1] == len(values):
        ngroups = len(bins)
    else:
        ngroups = len(bins) + 1

    N, K = (<object> values).shape

    b = 0
    for i in range(N):
        # advance to the bin containing row i (bins are right edges)
        while b < ngroups - 1 and i >= bins[b]:
            b += 1

        counts[b] += 1
        for j in range(K):
            val = values[i, j]

            # not nan (nan != nan)
            if val == val:
                nobs[b, j] += 1
                if nobs[b, j] == rank:
                    resx[b, j] = val

    for i in range(ngroups):
        for j in range(K):
            # BUG FIX: the original tested ``nobs[i, j] == 0`` here, which
            # copied *uninitialized* resx memory into out whenever a cell
            # had some observations but fewer than ``rank``.
            if nobs[i, j] < rank:
                out[i, j] = nan
            else:
                out[i, j] = resx[i, j]

@cython.boundscheck(False)
@cython.wraparound(False)
def group_last(ndarray[float64_t, ndim=2] out,
               ndarray[int64_t] counts,
               ndarray[float64_t, ndim=2] values,
               ndarray[int64_t] labels):
    '''
    Write the last non-nan value of each (group, column) into ``out``.

    Only aggregates on axis=0.  Cells with no non-nan observations get
    nan.  ``counts[lab]`` is incremented once per row of group ``lab``
    (nan rows included); rows with a negative label (NA group) are
    skipped entirely.
    '''
    cdef:
        Py_ssize_t i, j, N, K, lab
        float64_t val, count
        ndarray[float64_t, ndim=2] resx
        ndarray[int64_t, ndim=2] nobs

    # nobs[lab, j] counts non-nan observations for that cell; resx keeps
    # being overwritten so it ends up holding the last one seen
    nobs = np.zeros((<object> out).shape, dtype=np.int64)
    resx = np.empty_like(out)

    N, K = (<object> values).shape

    for i in range(N):
        lab = labels[i]
        if lab < 0:
            # negative label marks a row belonging to no group
            continue

        counts[lab] += 1
        for j in range(K):
            val = values[i, j]

            # not nan (nan != nan)
            if val == val:
                nobs[lab, j] += 1
                resx[lab, j] = val

    for i in range(len(counts)):
        for j in range(K):
            if nobs[i, j] == 0:
                # no valid observations for this cell
                out[i, j] = nan
            else:
                out[i, j] = resx[i, j]


@cython.boundscheck(False)
@cython.wraparound(False)
def group_last_bin(ndarray[float64_t, ndim=2] out,
                   ndarray[int64_t] counts,
                   ndarray[float64_t, ndim=2] values,
                   ndarray[int64_t] bins):
    '''
    Write the last non-nan value of each (bin, column) into ``out``.

    Only aggregates on axis=0.  ``bins`` holds the (sorted) right edges
    of each bin.  Cells with no non-nan observations get nan.
    ``counts[b]`` is incremented once per row falling into bin ``b``
    (nan rows included).
    '''
    cdef:
        Py_ssize_t i, j, N, K, ngroups, b
        float64_t val, count
        ndarray[float64_t, ndim=2] resx
        ndarray[int64_t, ndim=2] nobs

    # Use an integer observation counter, consistent with group_last; the
    # original declared nobs as float64 and compared counts as floats.
    nobs = np.zeros((<object> out).shape, dtype=np.int64)
    resx = np.empty_like(out)

    # Guard the empty-bins case: with wraparound/boundscheck disabled,
    # bins[len(bins) - 1] would be an out-of-bounds read for len(bins) == 0.
    if len(bins) > 0 and bins[len(bins) - 1] == len(values):
        ngroups = len(bins)
    else:
        ngroups = len(bins) + 1

    N, K = (<object> values).shape

    b = 0
    for i in range(N):
        # advance to the bin containing row i (bins are right edges)
        while b < ngroups - 1 and i >= bins[b]:
            b += 1

        counts[b] += 1
        for j in range(K):
            val = values[i, j]

            # not nan (nan != nan); resx keeps the last value seen
            if val == val:
                nobs[b, j] += 1
                resx[b, j] = val

    for i in range(ngroups):
        for j in range(K):
            if nobs[i, j] == 0:
                # no valid observations for this cell
                out[i, j] = nan
            else:
                out[i, j] = resx[i, j]

#----------------------------------------------------------------------
# group_min, group_max


@cython.boundscheck(False)
@cython.wraparound(False)
Expand Down Expand Up @@ -787,6 +969,7 @@ def group_min_bin(ndarray[float64_t, ndim=2] out,
else:
out[i, j] = minx[i, j]


@cython.boundscheck(False)
@cython.wraparound(False)
def group_max_bin(ndarray[float64_t, ndim=2] out,
Expand Down
22 changes: 10 additions & 12 deletions pandas/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,25 +121,23 @@ def test_basic(self):
# corner cases
self.assertRaises(Exception, grouped.aggregate, lambda x: x * 2)

def test_first_last_nth(self):
    # tests for first / last / nth
    #
    # NOTE(review): this span is a commit-diff rendering with the +/-
    # markers stripped, so both the pre-change and post-change ``expected``
    # computations appear interleaved below (the later assignment to
    # ``expected`` wins before each frame-level assertion).  Reconcile
    # against the actual repository file before treating this as source.
    grouped = self.df.groupby('A')
    first = grouped.first()
    expected = grouped.get_group('bar')
    expected = expected.xs(expected.index[0])[1:]
    expected.name ='bar'
    assert_series_equal(first.xs('bar'), expected)
    # presumably first() drops the non-numeric 'B' column here — confirm
    expected = self.df.ix[[1, 0], ['C', 'D']]
    expected.index = ['bar', 'foo']
    assert_frame_equal(first, expected)

    last = grouped.last()
    expected = grouped.get_group('bar')
    expected = expected.xs(expected.index[-1])[1:]
    expected.name ='bar'
    assert_series_equal(last.xs('bar'), expected)
    expected = self.df.ix[[5, 7], ['C', 'D']]
    expected.index = ['bar', 'foo']
    assert_frame_equal(last, expected)

    nth = grouped.nth(1)
    expected = grouped.get_group('foo')
    expected = expected.xs(expected.index[1])[1:]
    expected.name ='foo'
    assert_series_equal(nth.xs('foo'), expected)
    # nth() keeps column 'B' while first()/last() dropped it above
    expected = self.df.ix[[3, 2], ['B', 'C', 'D']]
    expected.index = ['bar', 'foo']
    assert_frame_equal(nth, expected)

def test_empty_groups(self):
# GH # 1048
Expand Down
17 changes: 17 additions & 0 deletions vb_suite/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,20 @@ def f():

groupby_apply_dict_return = Benchmark('data.groupby(labels).apply(f)',
setup, start_date=datetime(2011, 12, 15))

#----------------------------------------------------------------------
# First / last functions

# Benchmark fixture: 10,000 groups of 10 rows each, with two thirds of the
# values set to nan so the skip-nan paths of first()/last() are exercised;
# labels are shuffled so groups are not contiguous in memory.
setup = common_setup + """
labels = np.arange(10000).repeat(10)
data = Series(randn(len(labels)))
data[::3] = np.nan
data[1::3] = np.nan
labels = labels.take(np.random.permutation(len(labels)))
"""

# start_date marks when the Cython first/last fast paths landed
groupby_first = Benchmark('data.groupby(labels).first()', setup,
                          start_date=datetime(2012, 5, 1))

groupby_last = Benchmark('data.groupby(labels).last()', setup,
                         start_date=datetime(2012, 5, 1))

0 comments on commit a98035c

Please sign in to comment.