Fixes np.unique on SparseArray #19651
File: pandas/_libs/hashtable_class_helper.pxi.in (file name inferred from the hunk contents)

@@ -251,25 +251,39 @@ cdef class HashTable:

{{py:

# name, dtype, null_condition, float_group
dtypes = [('Float64', 'float64', 'val != val', True),
          ('UInt64', 'uint64', 'False', False),
          ('Int64', 'int64', 'val == iNaT', False)]
dtypes = [('Float64', 'float64', 'val != val', True, 'NAN'),
          ('UInt64', 'uint64', 'False', False, 'NAN'),
          ('Int64', 'int64', 'val == iNaT', False, 'iNaT')]

def get_dispatch(dtypes):
    for (name, dtype, null_condition, float_group) in dtypes:
    for (name, dtype, null_condition, float_group, na_value) in dtypes:
        unique_template = """\
        cdef:
            Py_ssize_t i, n = len(values)
            int ret = 0
            {dtype}_t val
            {dtype}_t val, fill_value_val, ngaps_val
            khiter_t k
            bint seen_na = 0
            {name}Vector uniques = {name}Vector()
            {name}VectorData *ud

        ud = uniques.data

        fill_value_val = fill_value
        ngaps_val = ngaps

        with nogil:
            # If this is a sparse structure we need to append
            # the fill value as well, assuming ngaps is greater than 0

            if ngaps_val > 0:
                k = kh_get_{dtype}(self.table, fill_value_val)
                if k == self.table.n_buckets:
                    kh_put_{dtype}(self.table, fill_value_val, &ret)
                    if needs_resize(ud):
                        with gil:
                            uniques.resize()
                    append_data_{dtype}(ud, fill_value_val)

            for i in range(n):
                val = values[i]
                IF {float_group}:

Review comment (on the fill-value block above): This is duplicated 3 times in this class. Would be nice to de-dupe this somehow without making it super complicated with string interpolation...
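For context: a SparseArray stores only its non-fill values densely, so a hash-table pass over those values alone never sees the fill value; ngaps says how many positions that fill value occupies. A minimal sketch of the inputs involved (illustrative only, using the public SparseArray attributes the diff reads):

```python
import pandas as pd

# Illustrative: what the hashtable actually receives for a sparse input.
arr = pd.SparseArray([0, 1, 1, 0, 2], fill_value=0)

print(arr.sp_values)       # [1 1 2] -- the only values the table scans
print(arr.sp_index.ngaps)  # 2       -- positions occupied by the fill value
print(arr.fill_value)      # 0       -- must be folded into the uniques when ngaps > 0
```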
@@ -300,11 +314,11 @@ def get_dispatch(dtypes):

        unique_template = unique_template.format(name=name, dtype=dtype, null_condition=null_condition, float_group=float_group)

        yield (name, dtype, null_condition, float_group, unique_template)
        yield (name, dtype, null_condition, float_group, unique_template, na_value)
}}


{{for name, dtype, null_condition, float_group, unique_template in get_dispatch(dtypes)}}
{{for name, dtype, null_condition, float_group, unique_template, na_value in get_dispatch(dtypes)}}

cdef class {{name}}HashTable(HashTable):
@@ -405,22 +419,27 @@ cdef class {{name}}HashTable(HashTable):
        labels = self.get_labels(values, uniques, 0, 0)
        return uniques.to_array(), labels

    # This seems like duplicate code from def unique to me...
    # Why does this exist?
    @cython.boundscheck(False)
    def get_labels(self, {{dtype}}_t[:] values, {{name}}Vector uniques,
                   Py_ssize_t count_prior, Py_ssize_t na_sentinel,
                   bint check_null=True):
                   bint check_null=True, fill_value={{na_value}}, ngaps=0):
        cdef:
            Py_ssize_t i, n = len(values)
            int64_t[:] labels
            Py_ssize_t idx, count = count_prior
            int ret = 0
            {{dtype}}_t val
            {{dtype}}_t val, fill_value_val, ngaps_val
            khiter_t k
            {{name}}VectorData *ud

        labels = np.empty(n, dtype=np.int64)
        ud = uniques.data

        if ngaps > 0:
            print("Hello world")

        with nogil:
            for i in range(n):
                val = values[i]
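The ngaps branch in get_labels above is still a placeholder (it only prints). For reference, this is the dense behaviour the sparse path ultimately has to reproduce: the fill value needs a real label and a slot in the uniques. Illustration only, not part of the PR:

```python
import numpy as np
import pandas as pd

# Dense equivalent of SparseArray([0, 1, 1, 0, 2], fill_value=0)
dense = np.array([0, 1, 1, 0, 2], dtype="int64")
labels, uniques = pd.factorize(dense)

print(labels)   # [0 1 1 0 2] -- fill-value positions get a label, not -1
print(uniques)  # [0 1 2]     -- the fill value appears among the uniques
```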
@@ -496,10 +515,10 @@ cdef class {{name}}HashTable(HashTable):
        return np.asarray(labels), arr_uniques

    @cython.boundscheck(False)
    def unique(self, ndarray[{{dtype}}_t, ndim=1] values):
    def unique(self, ndarray[{{dtype}}_t, ndim=1] values, fill_value={{na_value}}, ngaps=0):
        if values.flags.writeable:
            # If the value is writeable (mutable) then use memview
            return self.unique_memview(values)
            return self.unique_memview(values, fill_value=fill_value, ngaps=ngaps)

        # We cannot use the memoryview version on readonly-buffers due to
        # a limitation of Cython's typed memoryviews. Instead we can use
@@ -508,7 +527,7 @@ cdef class {{name}}HashTable(HashTable):
        {{unique_template}}

    @cython.boundscheck(False)
    def unique_memview(self, {{dtype}}_t[:] values):
    def unique_memview(self, {{dtype}}_t[:] values, fill_value={{na_value}}, ngaps=0):
        {{unique_template}}

{{endfor}}
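Taken together, the generated per-dtype tables now accept the sparse metadata directly. A rough usage sketch of the signature this diff introduces (only valid with this branch applied; Int64HashTable itself lives in pandas._libs.hashtable):

```python
import numpy as np
from pandas._libs import hashtable as htable

# sp_values of SparseArray([0, 1, 1, 0, 2], fill_value=0)
sp_values = np.array([1, 1, 2], dtype="int64")

table = htable.Int64HashTable(len(sp_values))
# With ngaps > 0, the patched unique() folds the fill value into the result.
uniques = table.unique(sp_values, fill_value=0, ngaps=2)
print(uniques)  # expected [0 1 2] (insertion order, fill value first)
```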
File: pandas/core/algorithms.py (file name inferred from the hunk contents)
@@ -10,7 +10,8 @@
    maybe_promote, construct_1d_object_array_from_listlike)
from pandas.core.dtypes.generic import (
    ABCSeries, ABCIndex,
    ABCIndexClass, ABCCategorical)
    ABCIndexClass, ABCCategorical,
    ABCSparseArray)
from pandas.core.dtypes.common import (
    is_unsigned_integer_dtype, is_signed_integer_dtype,
    is_integer_dtype, is_complex_dtype,
@@ -362,7 +363,12 @@ def unique(values):
    htable, _, values, dtype, ndtype = _get_hashtable_algo(values)

    table = htable(len(values))
    uniques = table.unique(values)

    if isinstance(values, ABCSparseArray):
        uniques = table.unique(values, fill_value=values.fill_value,
                               ngaps=values.sp_index.ngaps)
    else:
        uniques = table.unique(values)
    uniques = _reconstruct_data(uniques, dtype, original)

    if isinstance(original, ABCSeries) and is_datetime64tz_dtype(dtype):

Review thread on the ABCSparseArray branch:
Reviewer: not sure you saw my comment. this should not be handling unique like this. SparseArray should have a method ...
Reviewer: IOW you might be able to get away with doing a ...
Reviewer: need to avoid recursion, but here values must be a ndarray or an EA, including Categorical. (and NOT a Series)
Author: Hrm, I'll have to think about this... I'm hesitant to do this mainly because this code seems to rely on the fact that it always outputs an ndarray, which is why groupby doesn't work with sparse data. The problem really is with classes that implement their own unique(): it usually outputs something that isn't an ndarray. Also I don't want to call unique on the class and cast it back to ndarray cause that feels sloppy. I'll think of something :) thanks!
Reviewer: @hexgnu have you been following the ExtensionArray stuff? Eventually SparseArray.unique() should return a SparseArray, just like Categorical.unique returns a Categorical.
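For comparison, a rough sketch of where the review discussion is heading: put the dispatch on SparseArray itself, mirroring Categorical.unique, so algorithms.unique needs no sparse-specific branch. The function below is hypothetical, not code from this PR; it works on sp_values (a plain ndarray), which is how it sidesteps the recursion concern raised above:

```python
import numpy as np

def sparse_array_unique(arr):
    # Hypothetical SparseArray.unique(): uniques of the stored values,
    # plus the fill value whenever there are gaps, wrapped back up sparse.
    uniques = np.unique(arr.sp_values)
    if arr.sp_index.ngaps > 0:
        uniques = np.append(uniques, arr.fill_value)
    return type(arr)(uniques, fill_value=arr.fill_value)
```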
@@ -469,7 +475,13 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None):
    table = hash_klass(size_hint or len(values))
    uniques = vec_klass()
    check_nulls = not is_integer_dtype(original)
    labels = table.get_labels(values, uniques, 0, na_sentinel, check_nulls)

    if isinstance(values, ABCSparseArray):
        labels = table.get_labels(values, uniques, 0, na_sentinel, check_nulls,
                                  fill_value=values.fill_value,
                                  ngaps=values.sp_index.ngaps)
    else:
        labels = table.get_labels(values, uniques, 0, na_sentinel, check_nulls)

    labels = _ensure_platform_int(labels)
    uniques = uniques.to_array()
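The isinstance branch above is now repeated in both unique() and factorize(), echoing the duplication complaint on the Cython side. One way to keep a single call site (helper name and placement are hypothetical, not something this PR adds):

```python
def _sparse_table_kwargs(values):
    # Hypothetical helper: extra keywords for the patched hashtable methods,
    # empty for non-sparse inputs so the **kwargs expansion is a no-op.
    if isinstance(values, ABCSparseArray):
        return {"fill_value": values.fill_value,
                "ngaps": values.sp_index.ngaps}
    return {}

# Both call sites would then collapse to:
#   uniques = table.unique(values, **_sparse_table_kwargs(values))
#   labels = table.get_labels(values, uniques, 0, na_sentinel, check_nulls,
#                             **_sparse_table_kwargs(values))
```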
Review comment: update the comment