Fixes np.unique on SparseArray #19651

Closed
wants to merge 11 commits
45 changes: 32 additions & 13 deletions pandas/_libs/hashtable_class_helper.pxi.in
@@ -251,25 +251,39 @@ cdef class HashTable:
{{py:

# name, dtype, null_condition, float_group
Reviewer comment (Contributor): update the comment

dtypes = [('Float64', 'float64', 'val != val', True),
('UInt64', 'uint64', 'False', False),
('Int64', 'int64', 'val == iNaT', False)]
dtypes = [('Float64', 'float64', 'val != val', True, 'NAN'),
('UInt64', 'uint64', 'False', False, 'NAN'),
('Int64', 'int64', 'val == iNaT', False, 'iNaT')]

def get_dispatch(dtypes):
    for (name, dtype, null_condition, float_group) in dtypes:
    for (name, dtype, null_condition, float_group, na_value) in dtypes:
        unique_template = """\
        cdef:
            Py_ssize_t i, n = len(values)
            int ret = 0
            {dtype}_t val
            {dtype}_t val, fill_value_val, ngaps_val
            khiter_t k
            bint seen_na = 0
            {name}Vector uniques = {name}Vector()
            {name}VectorData *ud

        ud = uniques.data

        fill_value_val = fill_value
        ngaps_val = ngaps

        with nogil:
            # If this is a sparse structure, we also need to append
            # the fill value, assuming ngaps is greater than 0

            if ngaps_val > 0:
                k = kh_get_{dtype}(self.table, fill_value_val)
Inline comment (PR author): This is duplicated 3 times in this class. Would be nice to de-dupe this somehow without making it super complicated with string interpolation...

(One possible de-duplication approach is sketched after this hunk.)

                if k == self.table.n_buckets:
                    kh_put_{dtype}(self.table, fill_value_val, &ret)
                    if needs_resize(ud):
                        with gil:
                            uniques.resize()
                    append_data_{dtype}(ud, fill_value_val)

            for i in range(n):
                val = values[i]
                IF {float_group}:
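On the author's de-duplication question above: one hedged option, sketched here as an editor's illustration rather than anything in the PR, is to render the fill-value block once per dtype in a helper and splice the result into each method template through a hypothetical {fill_value_block} placeholder.

# Editor's sketch: factor the duplicated fill-value handling into one
# helper. The helper and the {fill_value_block} placeholder are
# hypothetical; the PR itself repeats this block in each template.
def make_fill_value_block(dtype):
    template = """\
            if ngaps_val > 0:
                k = kh_get_{dtype}(self.table, fill_value_val)
                if k == self.table.n_buckets:
                    kh_put_{dtype}(self.table, fill_value_val, &ret)
                    if needs_resize(ud):
                        with gil:
                            uniques.resize()
                    append_data_{dtype}(ud, fill_value_val)
"""
    return template.format(dtype=dtype)

# Each method template would then contain {fill_value_block} where the
# duplicated code currently lives, e.g.:
# unique_template = unique_template.format(
#     fill_value_block=make_fill_value_block(dtype), name=name, ...)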
@@ -300,11 +314,11 @@ def get_dispatch(dtypes):

        unique_template = unique_template.format(name=name, dtype=dtype, null_condition=null_condition, float_group=float_group)

        yield (name, dtype, null_condition, float_group, unique_template)
        yield (name, dtype, null_condition, float_group, unique_template, na_value)
}}


{{for name, dtype, null_condition, float_group, unique_template in get_dispatch(dtypes)}}
{{for name, dtype, null_condition, float_group, unique_template, na_value in get_dispatch(dtypes)}}

cdef class {{name}}HashTable(HashTable):

@@ -405,22 +419,27 @@ cdef class {{name}}HashTable(HashTable):
        labels = self.get_labels(values, uniques, 0, 0)
        return uniques.to_array(), labels

    # This seems like duplicate code from def unique to me...
    # Why does this exist?
    @cython.boundscheck(False)
    def get_labels(self, {{dtype}}_t[:] values, {{name}}Vector uniques,
                   Py_ssize_t count_prior, Py_ssize_t na_sentinel,
                   bint check_null=True):
                   bint check_null=True, fill_value={{na_value}}, ngaps=0):
        cdef:
            Py_ssize_t i, n = len(values)
            int64_t[:] labels
            Py_ssize_t idx, count = count_prior
            int ret = 0
            {{dtype}}_t val
            {{dtype}}_t val, fill_value_val, ngaps_val
            khiter_t k
            {{name}}VectorData *ud

        labels = np.empty(n, dtype=np.int64)
        ud = uniques.data

        if ngaps > 0:
            # debugging placeholder; the fill value is not yet handled
            # here the way it is in unique()
            print("Hello world")

        with nogil:
            for i in range(n):
                val = values[i]
@@ -496,10 +515,10 @@ cdef class {{name}}HashTable(HashTable):
        return np.asarray(labels), arr_uniques

    @cython.boundscheck(False)
    def unique(self, ndarray[{{dtype}}_t, ndim=1] values):
    def unique(self, ndarray[{{dtype}}_t, ndim=1] values, fill_value={{na_value}}, ngaps=0):
        if values.flags.writeable:
            # If the values are writeable (mutable), use the memoryview version
            return self.unique_memview(values)
            return self.unique_memview(values, fill_value=fill_value, ngaps=ngaps)

        # We cannot use the memoryview version on readonly-buffers due to
        # a limitation of Cython's typed memoryviews. Instead we can use
@@ -508,7 +527,7 @@ cdef class {{name}}HashTable(HashTable):
{{unique_template}}

    @cython.boundscheck(False)
    def unique_memview(self, {{dtype}}_t[:] values):
    def unique_memview(self, {{dtype}}_t[:] values, fill_value={{na_value}}, ngaps=0):
{{unique_template}}

{{endfor}}
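In plain Python, the uniques logic these templates generate behaves roughly like the following model (an editor's sketch; the hashtable details are elided):

def unique_with_fill(sp_values, fill_value, ngaps):
    # Simplified model of the generated Cython: when the sparse array
    # has gaps, seed the uniques with the fill value, then append each
    # stored value the first time it is seen.
    seen = set()
    uniques = []
    if ngaps > 0:
        seen.add(fill_value)
        uniques.append(fill_value)
    for val in sp_values:
        if val not in seen:
            seen.add(val)
            uniques.append(val)
    return uniques

# e.g. a sparse [0, 0, 1, 2] with fill_value=0 stores sp_values=[1, 2]
# and ngaps=2, so unique_with_fill([1, 2], 0, 2) == [0, 1, 2]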
18 changes: 15 additions & 3 deletions pandas/core/algorithms.py
@@ -10,7 +10,8 @@
    maybe_promote, construct_1d_object_array_from_listlike)
from pandas.core.dtypes.generic import (
    ABCSeries, ABCIndex,
    ABCIndexClass, ABCCategorical)
    ABCIndexClass, ABCCategorical,
    ABCSparseArray)
from pandas.core.dtypes.common import (
    is_unsigned_integer_dtype, is_signed_integer_dtype,
    is_integer_dtype, is_complex_dtype,
@@ -362,7 +363,12 @@ def unique(values):
    htable, _, values, dtype, ndtype = _get_hashtable_algo(values)

    table = htable(len(values))
    uniques = table.unique(values)

    if isinstance(values, ABCSparseArray):
Reviewer comment (Contributor): Not sure you saw my comment. This should not be handling unique like this. SparseArray should have a method .unique() which can directly handle the unique (and then call algorithms.unique on the sparse values).

Reviewer comment (Contributor): IOW you might be able to get away with doing a

if hasattr(values, 'unique'):
    return values.unique()
...

need to avoid recursion, but here values must be an ndarray or an EA, including Categorical (and NOT a Series).

cc @TomAugspurger

Reply (PR author): Hrm, I'll have to think about this... I'm hesitant to do this mainly because this code seems to rely on the fact that it always outputs an ndarray, which is why groupby doesn't work with sparse data.

The problem really is with classes that implement their own unique(). It usually outputs something that isn't an ndarray.

Also I don't want to call unique on the class and cast it back to ndarray because that feels sloppy. I'll think of something :) thanks!

Reviewer comment (Contributor): @hexgnu have you been following the ExtensionArray stuff? Eventually SparseArray should be an ExtensionArray like Categorical, and things like groupby, factorize, etc. will all work correctly on EAs.

> Also I don't want to call unique on the class and cast it back to ndarray cause that feels sloppy.

SparseArray.unique() should return a SparseArray, just like Categorical.unique returns a Categorical.
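A hedged sketch of the suggested approach (an editor's illustration, not code from this PR): SparseArray.unique would call algorithms.unique on the dense sp_values, which is a plain ndarray, so there is no recursion, then account for the fill value itself and return the same type.

import numpy as np

def sparse_array_unique(arr):
    # Sketch only: `arr` is assumed to be a pandas SparseArray with the
    # sp_values / sp_index / fill_value attributes used elsewhere in
    # this PR.
    from pandas.core import algorithms

    # sp_values is a plain ndarray, so this cannot recurse back here.
    uniques = algorithms.unique(arr.sp_values)

    # If the array has gaps, the implicit fill value is a value too.
    # (A real implementation would need NaN-aware membership testing,
    # since `nan in uniques` is always False.)
    if arr.sp_index.ngaps > 0 and arr.fill_value not in uniques:
        uniques = np.append(uniques, arr.fill_value)

    # Per the review, return the same type, like Categorical.unique.
    return type(arr)(uniques)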

        uniques = table.unique(values, fill_value=values.fill_value,
                               ngaps=values.sp_index.ngaps)
    else:
        uniques = table.unique(values)
    uniques = _reconstruct_data(uniques, dtype, original)

    if isinstance(original, ABCSeries) and is_datetime64tz_dtype(dtype):
@@ -469,7 +475,13 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None):
    table = hash_klass(size_hint or len(values))
    uniques = vec_klass()
    check_nulls = not is_integer_dtype(original)
    labels = table.get_labels(values, uniques, 0, na_sentinel, check_nulls)

    if isinstance(values, ABCSparseArray):
        labels = table.get_labels(values, uniques, 0, na_sentinel, check_nulls,
                                  fill_value=values.fill_value,
                                  ngaps=values.sp_index.ngaps)
    else:
        labels = table.get_labels(values, uniques, 0, na_sentinel, check_nulls)

    labels = _ensure_platform_int(labels)
    uniques = uniques.to_array()
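Taken together, the two call sites above give sparse inputs a fill-value-aware path. A hedged usage sketch of the intended behavior, assuming a build with this patch applied:

import numpy as np
import pandas as pd
from pandas.core import algorithms

# Two of the four positions are gaps stored implicitly as fill_value=0,
# so sp_values is [1, 2] and sp_index.ngaps is 2.
arr = pd.SparseArray([0, 0, 1, 2], fill_value=0)

# With the patch, the hashtable is seeded with the fill value whenever
# ngaps > 0, so the gaps' value is no longer dropped from the uniques.
print(algorithms.unique(arr))  # expected to contain 0, 1 and 2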
1 change: 1 addition & 0 deletions pandas/core/groupby.py
@@ -3024,6 +3024,7 @@ def is_in_obj(gpr):

        # create the Grouping
        # allow us to passing the actual Grouping as the gpr

        ping = Grouping(group_axis,
                        gpr,
                        obj=obj,