Skip to content

Commit

Permalink
Refactor groupby group_add from tempita to fused types (#24954)
Browse files Browse the repository at this point in the history
  • Loading branch information
noamher authored and jreback committed Feb 9, 2019
1 parent 0508d81 commit 2448e52
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 49 deletions.
51 changes: 51 additions & 0 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import cython
from cython import Py_ssize_t
from cython cimport floating

from libc.stdlib cimport malloc, free

Expand Down Expand Up @@ -382,5 +383,55 @@ def group_any_all(uint8_t[:] out,
out[lab] = flag_val


@cython.wraparound(False)
@cython.boundscheck(False)
def _group_add(floating[:, :] out,
int64_t[:] counts,
floating[:, :] values,
const int64_t[:] labels,
Py_ssize_t min_count=0):
"""
Only aggregates on axis=0
"""
cdef:
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
floating val, count
ndarray[floating, ndim=2] sumx, nobs

if not len(values) == len(labels):
raise AssertionError("len(index) != len(labels)")

nobs = np.zeros_like(out)
sumx = np.zeros_like(out)

N, K = (<object>values).shape

with nogil:

for i in range(N):
lab = labels[i]
if lab < 0:
continue

counts[lab] += 1
for j in range(K):
val = values[i, j]

# not nan
if val == val:
nobs[lab, j] += 1
sumx[lab, j] += val

for i in range(ncounts):
for j in range(K):
if nobs[i, j] < min_count:
out[i, j] = NAN
else:
out[i, j] = sumx[i, j]


group_add_float32 = _group_add['float']
group_add_float64 = _group_add['double']

# generated from template
include "groupby_helper.pxi"
49 changes: 1 addition & 48 deletions pandas/_libs/groupby_helper.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ cdef extern from "numpy/npy_math.h":
_int64_max = np.iinfo(np.int64).max

# ----------------------------------------------------------------------
# group_add, group_prod, group_var, group_mean, group_ohlc
# group_prod, group_var, group_mean, group_ohlc
# ----------------------------------------------------------------------

{{py:
Expand All @@ -27,53 +27,6 @@ def get_dispatch(dtypes):
{{for name, c_type in get_dispatch(dtypes)}}


@cython.wraparound(False)
@cython.boundscheck(False)
def group_add_{{name}}({{c_type}}[:, :] out,
int64_t[:] counts,
{{c_type}}[:, :] values,
const int64_t[:] labels,
Py_ssize_t min_count=0):
"""
Only aggregates on axis=0
"""
cdef:
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
{{c_type}} val, count
ndarray[{{c_type}}, ndim=2] sumx, nobs

if not len(values) == len(labels):
raise AssertionError("len(index) != len(labels)")

nobs = np.zeros_like(out)
sumx = np.zeros_like(out)

N, K = (<object>values).shape

with nogil:

for i in range(N):
lab = labels[i]
if lab < 0:
continue

counts[lab] += 1
for j in range(K):
val = values[i, j]

# not nan
if val == val:
nobs[lab, j] += 1
sumx[lab, j] += val

for i in range(ncounts):
for j in range(K):
if nobs[i, j] < min_count:
out[i, j] = NAN
else:
out[i, j] = sumx[i, j]


@cython.wraparound(False)
@cython.boundscheck(False)
def group_prod_{{name}}({{c_type}}[:, :] out,
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def get_func(fname):
# otherwise find dtype-specific version, falling back to object
for dt in [dtype_str, 'object']:
f = getattr(libgroupby, "{fname}_{dtype_str}".format(
fname=fname, dtype_str=dtype_str), None)
fname=fname, dtype_str=dt), None)
if f is not None:
return f

Expand Down

0 comments on commit 2448e52

Please sign in to comment.