Skip to content

Commit

Permalink
TST: test custom _formatter for ExtensionArray + revert ExtensionArra…
Browse files Browse the repository at this point in the history
…yFormatter removal (#26845)

* TST: test custom _formatter for ExtensionArray

* Revert "REF: remove ExtensionArrayFormatter (#26833)"

This reverts commit a00659a.
  • Loading branch information
jorisvandenbossche authored and jreback committed Jun 14, 2019
1 parent 137a886 commit 7ecfa8e
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 16 deletions.
8 changes: 5 additions & 3 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,9 +1061,11 @@ def _format_with_header(self, header, **kwargs):

def _format_native_types(self, na_rep='NaN', quoting=None, **kwargs):
""" actually format my specific types """
from pandas.io.formats.format import format_array
return format_array(values=self, na_rep=na_rep, justify='all',
leading_space=False)
from pandas.io.formats.format import ExtensionArrayFormatter
return ExtensionArrayFormatter(values=self,
na_rep=na_rep,
justify='all',
leading_space=False).get_result()

def _format_data(self, name=None):

Expand Down
38 changes: 26 additions & 12 deletions pandas/io/formats/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,7 +849,7 @@ def _get_column_name_list(self):
# Array formatters


def format_array(values, formatter=None, float_format=None, na_rep='NaN',
def format_array(values, formatter, float_format=None, na_rep='NaN',
digits=None, space=None, justify='right', decimal='.',
leading_space=None):
"""
Expand Down Expand Up @@ -879,23 +879,14 @@ def format_array(values, formatter=None, float_format=None, na_rep='NaN',
List[str]
"""

if is_extension_array_dtype(values.dtype):
if isinstance(values, (ABCIndexClass, ABCSeries)):
values = values._values

if is_categorical_dtype(values.dtype):
# Categorical is special for now, so that we can preserve tzinfo
values = values.get_values()

if not is_datetime64tz_dtype(values.dtype):
values = np.asarray(values)

if is_datetime64_dtype(values.dtype):
fmt_klass = Datetime64Formatter
elif is_datetime64tz_dtype(values):
fmt_klass = Datetime64TZFormatter
elif is_timedelta64_dtype(values.dtype):
fmt_klass = Timedelta64Formatter
elif is_extension_array_dtype(values.dtype):
fmt_klass = ExtensionArrayFormatter
elif is_float_dtype(values.dtype) or is_complex_dtype(values.dtype):
fmt_klass = FloatArrayFormatter
elif is_integer_dtype(values.dtype):
Expand Down Expand Up @@ -1190,6 +1181,29 @@ def _format_strings(self):
return fmt_values.tolist()


class ExtensionArrayFormatter(GenericArrayFormatter):
def _format_strings(self):
values = self.values
if isinstance(values, (ABCIndexClass, ABCSeries)):
values = values._values

formatter = values._formatter(boxed=True)

if is_categorical_dtype(values.dtype):
# Categorical is special for now, so that we can preserve tzinfo
array = values.get_values()
else:
array = np.asarray(values)

fmt_values = format_array(array,
formatter,
float_format=self.float_format,
na_rep=self.na_rep, digits=self.digits,
space=self.space, justify=self.justify,
leading_space=self.leading_space)
return fmt_values


def format_percentiles(percentiles):
"""
Outputs rounded and formatted percentiles.
Expand Down
5 changes: 5 additions & 0 deletions pandas/tests/extension/decimal/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ def isna(self):
def _na_value(self):
return decimal.Decimal('NaN')

def _formatter(self, boxed=False):
if boxed:
return "Decimal: {0}".format
return repr

@classmethod
def _concat_same_type(cls, to_concat):
return cls(np.concatenate([x._data for x in to_concat]))
Expand Down
8 changes: 7 additions & 1 deletion pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,13 @@ class TestSetitem(BaseDecimal, base.BaseSetitemTests):


class TestPrinting(BaseDecimal, base.BasePrintingTests):
pass

def test_series_repr(self, data):
# Overriding this base test to explicitly test that
# the custom _formatter is used
ser = pd.Series(data)
assert data.dtype.name in repr(ser)
assert "Decimal: " in repr(ser)


# TODO(extension)
Expand Down

0 comments on commit 7ecfa8e

Please sign in to comment.