Skip to content

Commit

Permalink
BUG: Handle all-NA blocks in concat (#20382)
Browse files Browse the repository at this point in the history
* BUG: Handle all-NA blocks in concat

Previously we special cased all-na blocks. We should only do that for
non-extension dtypes.
  • Loading branch information
TomAugspurger authored Mar 18, 2018
1 parent 670c2e4 commit aaa4d90
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 21 deletions.
3 changes: 2 additions & 1 deletion pandas/core/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5391,7 +5391,8 @@ def is_uniform_join_units(join_units):
# all blocks need to have the same type
all(type(ju.block) is type(join_units[0].block) for ju in join_units) and # noqa
# no blocks that would get missing values (can lead to type upcasts)
all(not ju.is_na for ju in join_units) and
# unless we're an extension dtype.
all(not ju.is_na or ju.block.is_extension for ju in join_units) and
# no blocks with indexers (as then the dimensions do not fit)
all(not ju.indexers for ju in join_units) and
# disregard Panels
Expand Down
15 changes: 15 additions & 0 deletions pandas/tests/extension/base/reshaping.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,21 @@ def test_concat(self, data, in_frame):
assert dtype == data.dtype
assert isinstance(result._data.blocks[0], ExtensionBlock)

@pytest.mark.parametrize('in_frame', [True, False])
def test_concat_all_na_block(self, data_missing, in_frame):
valid_block = pd.Series(data_missing.take([1, 1]), index=[0, 1])
na_block = pd.Series(data_missing.take([0, 0]), index=[2, 3])
if in_frame:
valid_block = pd.DataFrame({"a": valid_block})
na_block = pd.DataFrame({"a": na_block})
result = pd.concat([valid_block, na_block])
if in_frame:
expected = pd.DataFrame({"a": data_missing.take([1, 1, 0, 0])})
self.assert_frame_equal(result, expected)
else:
expected = pd.Series(data_missing.take([1, 1, 0, 0]))
self.assert_series_equal(result, expected)

def test_align(self, data, na_value):
a = data[:3]
b = data[2:5]
Expand Down
31 changes: 11 additions & 20 deletions pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,31 +36,22 @@ def na_value():


class BaseDecimal(object):
@staticmethod
def assert_series_equal(left, right, *args, **kwargs):
# tm.assert_series_equal doesn't handle Decimal('NaN').
# We will ensure that the NA values match, and then
# drop those values before moving on.

def assert_series_equal(self, left, right, *args, **kwargs):

left_na = left.isna()
right_na = right.isna()

tm.assert_series_equal(left_na, right_na)
tm.assert_series_equal(left[~left_na], right[~right_na],
*args, **kwargs)

@staticmethod
def assert_frame_equal(left, right, *args, **kwargs):
# TODO(EA): select_dtypes
decimals = (left.dtypes == 'decimal').index

for col in decimals:
BaseDecimal.assert_series_equal(left[col], right[col],
*args, **kwargs)

left = left.drop(columns=decimals)
right = right.drop(columns=decimals)
tm.assert_frame_equal(left, right, *args, **kwargs)
return tm.assert_series_equal(left[~left_na],
right[~right_na],
*args, **kwargs)

def assert_frame_equal(self, left, right, *args, **kwargs):
self.assert_series_equal(left.dtypes, right.dtypes)
for col in left.columns:
self.assert_series_equal(left[col], right[col],
*args, **kwargs)


class TestDtype(BaseDecimal, base.BaseDtypeTests):
Expand Down

0 comments on commit aaa4d90

Please sign in to comment.