Skip to content

Commit

Permalink
BUG/API: .merge() and .join() on category dtype columns will now pres…
Browse files Browse the repository at this point in the history
…erve the category dtype when possible

closes #10409
  • Loading branch information
jreback committed Mar 10, 2017
1 parent 5dee1f1 commit a4b2ee6
Show file tree
Hide file tree
Showing 8 changed files with 288 additions and 71 deletions.
36 changes: 30 additions & 6 deletions asv_bench/benchmarks/join_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pandas import ordered_merge as merge_ordered


#----------------------------------------------------------------------
# ----------------------------------------------------------------------
# Append

class Append(object):
Expand Down Expand Up @@ -35,7 +35,7 @@ def time_append_mixed(self):
self.mdf1.append(self.mdf2)


#----------------------------------------------------------------------
# ----------------------------------------------------------------------
# Concat

class Concat(object):
Expand Down Expand Up @@ -120,7 +120,7 @@ def time_f_ordered_axis1(self):
concat(self.frames_f, axis=1, ignore_index=True)


#----------------------------------------------------------------------
# ----------------------------------------------------------------------
# Joins

class Join(object):
Expand Down Expand Up @@ -202,7 +202,7 @@ def time_join_non_unique_equal(self):
(self.fracofday * self.temp[self.fracofday.index])


#----------------------------------------------------------------------
# ----------------------------------------------------------------------
# Merges

class Merge(object):
Expand Down Expand Up @@ -257,7 +257,31 @@ def time_i8merge(self):
merge(self.left, self.right, how='outer')


#----------------------------------------------------------------------
class MergeCategoricals(object):
    """Time ``merge`` on an integer key with object vs. category payloads.

    ``setup`` builds two pairs of frames that share the integer key column
    ``X``: one pair whose string payload column (``Y`` / ``Z``) is left as
    created, and one pair where that payload column is cast to ``category``.
    """
    goal_time = 0.2

    def setup(self):
        # every column is a 1-d random sample of this shape
        shape = (10000,)

        # NOTE: the draws happen in the order X, Y (left) then X, Z (right)
        left_columns = {
            'X': np.random.choice(range(0, 10), size=shape),
            'Y': np.random.choice(['one', 'two', 'three'], size=shape),
        }
        self.left_object = pd.DataFrame(left_columns)

        right_columns = {
            'X': np.random.choice(range(0, 10), size=shape),
            'Z': np.random.choice(['jjj', 'kkk', 'sss'], size=shape),
        }
        self.right_object = pd.DataFrame(right_columns)

        # same data, but with the payload column stored as category
        left_payload = self.left_object['Y'].astype('category')
        self.left_cat = self.left_object.assign(Y=left_payload)

        right_payload = self.right_object['Z'].astype('category')
        self.right_cat = self.right_object.assign(Z=right_payload)

    def time_merge_object(self):
        # merge on the integer key; payload columns are not categorical
        merge(self.left_object, self.right_object, on='X')

    def time_merge_cat(self):
        # merge on the integer key; payload columns are categorical
        merge(self.left_cat, self.right_cat, on='X')


# ----------------------------------------------------------------------
# Ordered merge

class MergeOrdered(object):
Expand Down Expand Up @@ -332,7 +356,7 @@ def time_multiby(self):
merge_asof(self.df1e, self.df2e, on='time', by=['key', 'key2'])


#----------------------------------------------------------------------
# ----------------------------------------------------------------------
# data alignment

class Align(object):
Expand Down
4 changes: 3 additions & 1 deletion doc/source/whatsnew/v0.20.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ Other API Changes
- Reorganization of timeseries development tests (:issue:`14854`)
- Specific support for ``copy.copy()`` and ``copy.deepcopy()`` functions on NDFrame objects (:issue:`15444`)
- ``Series.sort_values()`` accepts a one element list of bool for consistency with the behavior of ``DataFrame.sort_values()`` (:issue:`15604`)
- ``DataFrame.iterkv()`` has been removed in favor of ``DataFrame.iteritems()`` (:issue:`10711`)
- ``.merge()`` and ``.join()`` on ``category`` dtype columns will now preserve the category dtype when possible (:issue:`10409`)

.. _whatsnew_0200.deprecations:

Expand Down Expand Up @@ -733,6 +733,7 @@ Removal of prior version deprecations/changes
- ``Series.is_time_series`` is dropped in favor of ``Series.index.is_all_dates`` (:issue:`15098`)
- The deprecated ``irow``, ``icol``, ``iget`` and ``iget_value`` methods are removed
in favor of ``iloc`` and ``iat`` as explained :ref:`here <whatsnew_0170.deprecations>` (:issue:`10711`).
- The deprecated ``DataFrame.iterkv()`` has been removed in favor of ``DataFrame.iteritems()`` (:issue:`10711`)


.. _whatsnew_0200.performance:
Expand All @@ -749,6 +750,7 @@ Performance Improvements
- When reading buffer object in ``read_sas()`` method without specified format, filepath string is inferred rather than buffer object. (:issue:`14947`)
- Improved performance of ``.rank()`` for categorical data (:issue:`15498`)
- Improved performance when using ``.unstack()`` (:issue:`15503`)
- Improved performance of merge/join on ``category`` columns (:issue:`10409`)


.. _whatsnew_0200.bug_fixes:
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -5227,6 +5227,8 @@ def get_reindexed_values(self, empty_dtype, upcasted_na):
# External code requested filling/upcasting, bool values must
# be upcasted to object to avoid being upcasted to numeric.
values = self.block.astype(np.object_).values
elif self.block.is_categorical:
values = self.block.values
else:
# No dtype upcasting is done here, it will be performed during
# concatenation itself.
Expand Down
3 changes: 3 additions & 0 deletions pandas/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -4097,9 +4097,12 @@ def test_merge(self):
expected = df.copy()

# object-cat
# note that we propagate the category
# because we don't have any matching rows
cright = right.copy()
cright['d'] = cright['d'].astype('category')
result = pd.merge(left, cright, how='left', left_on='b', right_on='c')
expected['d'] = expected['d'].astype('category', categories=['null'])
tm.assert_frame_equal(result, expected)

# cat-object
Expand Down
177 changes: 145 additions & 32 deletions pandas/tests/tools/test_merge.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint: disable=E1103

import pytest
from datetime import datetime
from numpy.random import randn
from numpy import nan
Expand All @@ -11,6 +12,8 @@
from pandas.tools.concat import concat
from pandas.tools.merge import merge, MergeError
from pandas.util.testing import assert_frame_equal, assert_series_equal
from pandas.types.dtypes import CategoricalDtype
from pandas.types.common import is_categorical_dtype, is_object_dtype
from pandas import DataFrame, Index, MultiIndex, Series, Categorical
import pandas.util.testing as tm

Expand Down Expand Up @@ -1024,38 +1027,6 @@ def test_left_join_index_multi_match(self):
expected.index = np.arange(len(expected))
tm.assert_frame_equal(result, expected)

def test_join_multi_dtypes(self):
# Joins a frame onto a MultiIndex-ed frame for every combination of
# key dtype (d1) and value dtype (d2) listed at the bottom of the method.
# NOTE(review): in this diff view these lines appear to be the removed
# version of the test; the same body is shown again further down as the
# parametrized TestMergeDtypes.test_join_multi_dtypes -- confirm against
# the repository file.

# test with multi dtypes in the join index
def _test(dtype1, dtype2):
# left: 24 rows; 'k1' uses the key dtype under test, 'k2' is string keys
left = DataFrame({'k1': np.array([0, 1, 2] * 8, dtype=dtype1),
'k2': ['foo', 'bar'] * 12,
'v': np.array(np.arange(24), dtype=np.int64)})

# right: two rows indexed by (k1, k2) pairs; 'v2' uses the value dtype
index = MultiIndex.from_tuples([(2, 'bar'), (1, 'foo')])
right = DataFrame(
{'v2': np.array([5, 7], dtype=dtype2)}, index=index)

result = left.join(right, on=['k1', 'k2'])

expected = left.copy()

# an integer 'v2' is expected to come back as float64, since the
# expected column is initialised to NaN below
if dtype2.kind == 'i':
dtype2 = np.dtype('float64')
expected['v2'] = np.array(np.nan, dtype=dtype2)
# only the two (k1, k2) pairs present in ``right`` receive values
expected.loc[(expected.k1 == 2) & (expected.k2 == 'bar'), 'v2'] = 5
expected.loc[(expected.k1 == 1) & (expected.k2 == 'foo'), 'v2'] = 7

tm.assert_frame_equal(result, expected)

# same join with sort=True, compared against the expected frame
# sorted by the join keys
result = left.join(right, on=['k1', 'k2'], sort=True)
expected.sort_values(['k1', 'k2'], kind='mergesort', inplace=True)
tm.assert_frame_equal(result, expected)

# run the helper over the full key-dtype x value-dtype grid
for d1 in [np.int64, np.int32, np.int16, np.int8, np.uint8]:
for d2 in [np.int64, np.float64, np.float32, np.float16]:
_test(np.dtype(d1), np.dtype(d2))

def test_left_merge_na_buglet(self):
left = DataFrame({'id': list('abcde'), 'v1': randn(5),
'v2': randn(5), 'dummy': list('abcde'),
Expand Down Expand Up @@ -1242,3 +1213,145 @@ def f():
def f():
household.join(log_return, how='outer')
self.assertRaises(NotImplementedError, f)


@pytest.fixture
def df():
    """Two-row frame with one column per dtype exercised by the merge tests.

    Columns: ``A`` string list, ``B`` category, ``C`` integer list,
    ``D`` float list, ``E`` uint64 Series, ``F`` int32 Series.
    """
    columns = {
        'A': ['foo', 'bar'],
        'B': Series(['foo', 'bar']).astype('category'),
        'C': [1, 2],
        'D': [1.0, 2.0],
        'E': Series([1, 2], dtype='uint64'),
        'F': Series([1, 2], dtype='int32'),
    }
    return DataFrame(columns)


class TestMergeDtypes(object):
    """Result-dtype checks for ``merge`` and ``join``."""

    def test_different(self, df):
        # Merge the fixture frame against each of its own columns, presented
        # as the key column 'A' of the right-hand frame.  Every resulting
        # key column is asserted to be object dtype.
        for column_name in df.columns:
            other = DataFrame({'A': df[column_name]})
            merged = pd.merge(df, other, on='A')
            assert is_object_dtype(merged.A.dtype)

    @pytest.mark.parametrize('d1', [np.int64, np.int32,
                                    np.int16, np.int8, np.uint8])
    @pytest.mark.parametrize('d2', [np.int64, np.float64,
                                    np.float32, np.float16])
    def test_join_multi_dtypes(self, d1, d2):
        # d1 is the dtype of the integer join key, d2 the dtype of the
        # joined value column
        key_dtype = np.dtype(d1)
        value_dtype = np.dtype(d2)

        left = DataFrame({'k1': np.array([0, 1, 2] * 8, dtype=key_dtype),
                          'k2': ['foo', 'bar'] * 12,
                          'v': np.array(np.arange(24), dtype=np.int64)})

        right_index = MultiIndex.from_tuples([(2, 'bar'), (1, 'foo')])
        right = DataFrame({'v2': np.array([5, 7], dtype=value_dtype)},
                          index=right_index)

        joined = left.join(right, on=['k1', 'k2'])

        # an integer value column is expected to come back as float64,
        # because the expected column starts out filled with NaN
        if value_dtype.kind == 'i':
            expected_dtype = np.dtype('float64')
        else:
            expected_dtype = value_dtype

        expected = left.copy()
        expected['v2'] = np.array(np.nan, dtype=expected_dtype)

        # only the two key pairs present in ``right`` receive a value
        matches_5 = (expected.k1 == 2) & (expected.k2 == 'bar')
        expected.loc[matches_5, 'v2'] = 5
        matches_7 = (expected.k1 == 1) & (expected.k2 == 'foo')
        expected.loc[matches_7, 'v2'] = 7

        tm.assert_frame_equal(joined, expected)

        # same join with sort=True against the key-sorted expectation
        joined_sorted = left.join(right, on=['k1', 'k2'], sort=True)
        expected = expected.sort_values(['k1', 'k2'], kind='mergesort')
        tm.assert_frame_equal(joined_sorted, expected)


@pytest.fixture
def left():
    """Ten-row frame: categorical key ``X`` plus a random string column ``Y``.

    The global NumPy seed is fixed first, then ``X`` is drawn before ``Y``.
    """
    np.random.seed(1234)

    x_raw = np.random.choice(['foo', 'bar'], size=(10,))
    x_cat = Series(x_raw).astype('category', categories=['foo', 'bar'])

    y_raw = np.random.choice(['one', 'two', 'three'], size=(10,))

    return DataFrame({'X': x_cat, 'Y': y_raw})


@pytest.fixture
def right():
    """Two-row frame: categorical key ``X`` (foo, bar) and integer ``Z``."""
    # nothing random is drawn here; the seed call only resets global state
    np.random.seed(1234)

    x_cat = Series(['foo', 'bar']).astype('category',
                                          categories=['foo', 'bar'])
    return DataFrame({'X': x_cat, 'Z': [1, 2]})


class TestMergeCategorical(object):
    """Checks on which dtypes survive a ``merge`` involving categoricals."""

    def test_identical(self, left):
        # merging a frame with itself on the categorical key: the key is
        # expected to stay categorical, the suffixed payloads object
        merged = pd.merge(left, left, on='X')

        labels = ['X', 'Y_x', 'Y_y']
        wanted = [CategoricalDtype(), np.dtype('O'), np.dtype('O')]
        expected = Series(wanted, index=labels)

        assert_series_equal(merged.dtypes.sort_index(), expected)

    def test_basic(self, left, right):
        # both frames carry a matching categorical X, so the merged key
        # column is expected to remain categorical
        merged = pd.merge(left, right, on='X')

        labels = ['X', 'Y', 'Z']
        wanted = [CategoricalDtype(), np.dtype('O'), np.dtype('int64')]
        expected = Series(wanted, index=labels)

        assert_series_equal(merged.dtypes.sort_index(), expected)

    def test_other_columns(self, left, right):
        # a categorical column that is not a merge key is expected to
        # keep its dtype
        z_cat = right.Z.astype('category')
        right = right.assign(Z=z_cat)

        merged = pd.merge(left, right, on='X')

        labels = ['X', 'Y', 'Z']
        wanted = [CategoricalDtype(), np.dtype('O'), CategoricalDtype()]
        expected = Series(wanted, index=labels)
        assert_series_equal(merged.dtypes.sort_index(), expected)

        # the categoricals of the inputs and of the result compare dtype-equal
        assert left.X.values.is_dtype_equal(merged.X.values)
        assert right.Z.values.is_dtype_equal(merged.Z.values)

    @pytest.mark.parametrize(
        'change', [lambda x: x,
                   lambda x: x.astype('category',
                                      categories=['bar', 'foo']),
                   lambda x: x.astype('category',
                                      categories=['foo', 'bar', 'bah']),
                   lambda x: x.astype('category', ordered=True)])
    @pytest.mark.parametrize('how', ['inner', 'outer', 'left', 'right'])
    def test_dtype_on_merged_different(self, change, how, left, right):
        # ``change`` rewrites right.X (first cast to object) so that the two
        # key columns no longer share a dtype; the merged key is then
        # expected to be object for every ``how``
        right = right.assign(X=change(right.X.astype('object')))

        # preconditions: left key is categorical and differs from right key
        assert is_categorical_dtype(left.X.values)
        assert not left.X.values.is_dtype_equal(right.X.values)

        merged = pd.merge(left, right, on='X', how=how)

        labels = ['X', 'Y', 'Z']
        wanted = [np.dtype('O'), np.dtype('O'), np.dtype('int64')]
        expected = Series(wanted, index=labels)

        assert_series_equal(merged.dtypes.sort_index(), expected)
1 change: 1 addition & 0 deletions pandas/tests/tools/test_merge_asof.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def test_basic_categorical(self):
trades.ticker = trades.ticker.astype('category')
quotes = self.quotes.copy()
quotes.ticker = quotes.ticker.astype('category')
expected.ticker = expected.ticker.astype('category')

result = merge_asof(trades, quotes,
on='time',
Expand Down
50 changes: 39 additions & 11 deletions pandas/tests/types/test_common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-

import pytest
import numpy as np

from pandas.types.dtypes import DatetimeTZDtype, PeriodDtype, CategoricalDtype
Expand Down Expand Up @@ -38,17 +39,44 @@ def test_period_dtype(self):
self.assertEqual(pandas_dtype(dtype), dtype)


def test_dtype_equal():
assert is_dtype_equal(np.int64, np.int64)
assert not is_dtype_equal(np.int64, np.float64)
# One representative dtype per dtype family, keyed by a short family name.
# The mapping feeds the parametrized ``test_dtype_equal`` in this module.
# ``np.object_`` is used rather than the ``np.object`` alias, which is
# deprecated and no longer exists in newer NumPy releases.
dtypes = dict(datetime_tz=pandas_dtype('datetime64[ns, US/Eastern]'),
              datetime=pandas_dtype('datetime64[ns]'),
              timedelta=pandas_dtype('timedelta64[ns]'),
              period=PeriodDtype('D'),
              integer=np.dtype(np.int64),
              float=np.dtype(np.float64),
              object=np.dtype(np.object_),
              category=pandas_dtype('category'))

p1 = PeriodDtype('D')
p2 = PeriodDtype('D')
assert is_dtype_equal(p1, p2)
assert not is_dtype_equal(np.int64, p1)

p3 = PeriodDtype('2D')
assert not is_dtype_equal(p1, p3)
# The two stacked parametrize decorators pair every (name, dtype) entry of
# the module-level ``dtypes`` mapping with every other entry, including
# itself; ``ids`` renders each parameter with ``str``.
@pytest.mark.parametrize('name1,dtype1',
list(dtypes.items()),
ids=lambda x: str(x))
@pytest.mark.parametrize('name2,dtype2',
list(dtypes.items()),
ids=lambda x: str(x))
def test_dtype_equal(name1, dtype1, name2, dtype2):

# NOTE(review): the next two checks do not use the parametrized arguments;
# in this diff view they look like lines carried over from the previous
# version of the test -- confirm against the repository file.
assert not DatetimeTZDtype.is_dtype(np.int64)
assert not PeriodDtype.is_dtype(np.int64)
# match equal to self, but not equal to other
assert is_dtype_equal(dtype1, dtype1)
# entries with different names must compare as different dtypes
if name1 != name2:
assert not is_dtype_equal(dtype1, dtype2)


def test_dtype_equal_strict():
    """``is_dtype_equal`` must reject dtypes that are similar but not equal."""

    # integers of a different width than int64
    for narrower_int in (np.int8, np.int16, np.int32):
        assert not is_dtype_equal(np.int64, narrower_int)

    # floats of a different width than float64
    for narrower_float in (np.float32,):
        assert not is_dtype_equal(np.float64, narrower_float)

    # periods with different frequencies
    daily = PeriodDtype('D')
    two_daily = PeriodDtype('2D')
    assert not is_dtype_equal(daily, two_daily)

    # tz-aware datetimes with different timezones
    eastern = pandas_dtype('datetime64[ns, US/Eastern]')
    central_european = pandas_dtype('datetime64[ns, CET]')
    assert not is_dtype_equal(eastern, central_european)
Loading

0 comments on commit a4b2ee6

Please sign in to comment.