diff --git a/paderbox/array/sparse.py b/paderbox/array/sparse.py index 403fb8bd..1adc4e64 100644 --- a/paderbox/array/sparse.py +++ b/paderbox/array/sparse.py @@ -730,12 +730,14 @@ def _repr_pretty_(self, p, cycle): >>> a[:5] = 1 >>> a[7:] = 2 >>> pb.utils.pretty.pprint(a) - SparseArray(_SparseSegment(onset=0, array=array([1., 1., 1., 1., 1.], dtype=float32)), + SparseArray(_SparseSegment(onset=0, + array=array([1., 1., 1., 1., 1.], dtype=float32)), _SparseSegment(onset=7, array=array([2., 2., 2.], dtype=float32)), shape=(10,)) >>> a._pad_value = _get_pad_value(a.dtype, -1) >>> pb.utils.pretty.pprint(a) - SparseArray(_SparseSegment(onset=0, array=array([1., 1., 1., 1., 1.], dtype=float32)), + SparseArray(_SparseSegment(onset=0, + array=array([1., 1., 1., 1., 1.], dtype=float32)), _SparseSegment(onset=7, array=array([2., 2., 2.], dtype=float32)), shape=(10,), pad_value=-1.0) """ diff --git a/paderbox/utils/pretty.py b/paderbox/utils/pretty.py index 160be31b..b577245e 100644 --- a/paderbox/utils/pretty.py +++ b/paderbox/utils/pretty.py @@ -1,8 +1,56 @@ +import dataclasses import sys import io import IPython.lib.pretty import numpy as np +if sys.version_info >= (3, 8): + from IPython.lib.pretty import CallExpression +else: + # CallExpression was added in ipython 8, which dropped support for python 3.7 + # The following is a copy of the CallExpression class from IPython 8. + class CallExpression: + """ Object which emits a line-wrapped call expression in the form `__name(*args, **kwargs)` """ + + def __init__(__self, __name, *args, **kwargs): + # dunders are to avoid clashes with kwargs, as python's name manging + # will kick in. + self = __self + self.name = __name + self.args = args + self.kwargs = kwargs + + @classmethod + def factory(cls, name): + def inner(*args, **kwargs): + return cls(name, *args, **kwargs) + + return inner + + def _repr_pretty_(self, p, cycle): + # dunders are to avoid clashes with kwargs, as python's name manging + # will kick in. + + started = False + + def new_item(): + nonlocal started + if started: + p.text(",") + p.breakable() + started = True + + prefix = self.name + "(" + with p.group(len(prefix), prefix, ")"): + for arg in self.args: + new_item() + p.pretty(arg) + for arg_name, arg in self.kwargs.items(): + new_item() + arg_prefix = arg_name + "=" + with p.group(len(arg_prefix), arg_prefix): + p.pretty(arg) + class _MyRepresentationPrinter(IPython.lib.pretty.RepresentationPrinter): def __init__( @@ -95,6 +143,44 @@ def _enumerate(self, seq): yield from super()._enumerate(seq) self.depth -= 1 + @staticmethod + def _dataclass_repr_pretty_(self, p, cycle): + """ + >>> @dataclasses.dataclass + ... class PointClsWithALongName: + ... x: int + ... y: int + >>> pprint(PointClsWithALongName(1, 2), max_width=len('PointClsWithALongName') - 5) + PointClsWithALongName(x=1, + y=2) + >>> pprint(PointClsWithALongName(1, 2), max_width=len('PointClsWithALongName') + 9) + PointClsWithALongName(x=1, + y=2) + >>> pprint(PointClsWithALongName(1, 2), max_width=len('PointClsWithALongName') + 10) + PointClsWithALongName(x=1, y=2) + + >>> @dataclasses.dataclass + ... class PrettyPoint: + ... x: int + ... y: int + ... def _repr_pretty_(self, p, cycle): + ... p.text(f'CustomRepr(x={self.x}, y={self.y})') + >>> pprint(PrettyPoint(1, 2)) + CustomRepr(x=1, y=2) + + """ + p.pretty(CallExpression.factory( + self.__class__.__name__ + )( + **{k: getattr(self, k) for k in self.__dataclass_fields__.keys()} + )) + + def _in_deferred_types(self, cls): + if '_repr_pretty_' not in cls.__dict__ and dataclasses.is_dataclass(cls): + return self._dataclass_repr_pretty_ + else: + return super()._in_deferred_types(cls) + def pprint( obj, @@ -145,6 +231,8 @@ def pprint( >>> print(d.items()) dict_items([('aaaaaaaaaa', 1000000), ('bbbbbbbbbb', 2000000)]) + + """ printer = _MyRepresentationPrinter( @@ -208,7 +296,7 @@ def pretty( import paderbox as pb - def cli_pprint(file, max_seq_length=[10, 5, 2], max_width=None): + def cli_pprint(file, max_seq_length=[10, 5, 2], max_width=None, unsafe=False): """Load a file and pretty print it. With max_seq_length you can control the length of the printed sequences. @@ -218,7 +306,7 @@ def cli_pprint(file, max_seq_length=[10, 5, 2], max_width=None): The last entry is used for all larger depths. max_width: """ - data = pb.io.load(file) + data = pb.io.load(file, unsafe=unsafe) if max_width is None: max_width = shutil.get_terminal_size((79, 20)).columns diff --git a/paderbox/visualization/plot.py b/paderbox/visualization/plot.py index bf3e0fa7..85fc4c0a 100644 --- a/paderbox/visualization/plot.py +++ b/paderbox/visualization/plot.py @@ -52,6 +52,9 @@ def check_color(f): """ Improve the exception message for color, if color is an int. + Note: Since https://github.com/matplotlib/matplotlib/pull/27905 is merged, + this function is not needed anymore. + >>> fn = check_color(lambda **kwargs: kwargs) >>> fn(color=1) Traceback (most recent call last): @@ -59,10 +62,10 @@ def check_color(f): ValueError: The value of color is an integer. To get the N'th color, you can use f'C{N}', e.g. 'C1'. - >>> plt.plot(np.arange(10), color=1) + >>> plt.plot(np.arange(10), color=1) # doctest: +ELLIPSIS Traceback (most recent call last): ... - ValueError: 1 is not a valid value for color + ValueError: 1 is not a valid value for color... """ @wraps(f)