From b275fb7012e54d7212ee2ce6b79a19ee79466606 Mon Sep 17 00:00:00 2001 From: Christoph Boeddeker Date: Wed, 29 May 2024 16:01:43 +0200 Subject: [PATCH 1/7] add dataclass support to the pprinter function --- paderbox/utils/pretty.py | 42 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/paderbox/utils/pretty.py b/paderbox/utils/pretty.py index 160be31b..49dfe420 100644 --- a/paderbox/utils/pretty.py +++ b/paderbox/utils/pretty.py @@ -95,6 +95,48 @@ def _enumerate(self, seq): yield from super()._enumerate(seq) self.depth -= 1 + @staticmethod + def _dataclass_repr_pretty_(self, p, cycle): + """ + >>> @dataclasses.dataclass + ... class PointClsWithALongName: + ... x: int + ... y: int + >>> pprint(PointClsWithALongName(1, 2), max_width=len('PointClsWithALongName') + 8) + PointClsWithALongName( + x=1, + y=2 + ) + >>> pprint(PointClsWithALongName(1, 2), max_width=len('PointClsWithALongName') + 9) + PointClsWithALongName(x=1, y=2 + ) + >>> pprint(PointClsWithALongName(1, 2), max_width=len('PointClsWithALongName') + 10) + PointClsWithALongName(x=1, y=2) + """ + if cycle: + p.text(f'{self.__class__.__name__}(...)') + else: + txt = f'{self.__class__.__name__}(' + with p.group(4, txt, ''): + keys = self.__dataclass_fields__.keys() + for i, k in enumerate(keys): + if i: + p.breakable(sep=' ') + else: + p.breakable(sep='') + p.text(f'{k}=') + p.pretty(getattr(self, k)) + if i != len(keys) - 1: + p.text(',') + p.breakable('') + p.text(')') + + def _in_deferred_types(self, cls): + if dataclasses.is_dataclass(cls): + return self._dataclass_repr_pretty_ + else: + return super()._in_deferred_types(cls) + def pprint( obj, From 8f098af7149554a693c910cc8a0892d6b4eac114 Mon Sep 17 00:00:00 2001 From: Christoph Boeddeker Date: Wed, 29 May 2024 16:22:47 +0200 Subject: [PATCH 2/7] add missing dataclass import --- paderbox/utils/pretty.py | 1 + 1 file changed, 1 insertion(+) diff --git a/paderbox/utils/pretty.py b/paderbox/utils/pretty.py index 
49dfe420..b0a2cc58 100644 --- a/paderbox/utils/pretty.py +++ b/paderbox/utils/pretty.py @@ -1,3 +1,4 @@ +import dataclasses import sys import io import IPython.lib.pretty From cab37604c7a822e43c06ee8acca7ba69397e8d9c Mon Sep 17 00:00:00 2001 From: Christoph Boeddeker Date: Wed, 29 May 2024 16:55:20 +0200 Subject: [PATCH 3/7] fix pprint to preserve the original --- paderbox/utils/pretty.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/paderbox/utils/pretty.py b/paderbox/utils/pretty.py index b0a2cc58..7b1becca 100644 --- a/paderbox/utils/pretty.py +++ b/paderbox/utils/pretty.py @@ -113,6 +113,16 @@ def _dataclass_repr_pretty_(self, p, cycle): ) >>> pprint(PointClsWithALongName(1, 2), max_width=len('PointClsWithALongName') + 10) PointClsWithALongName(x=1, y=2) + + >>> @dataclasses.dataclass + ... class PrettyPoint: + ... x: int + ... y: int + ... def _repr_pretty_(self, p, cycle): + ... p.text(f'CustomRepr(x={self.x}, y={self.y})') + >>> pprint(PrettyPoint(1, 2)) + CustomRepr(x=1, y=2) + """ if cycle: p.text(f'{self.__class__.__name__}(...)') @@ -133,7 +143,7 @@ def _dataclass_repr_pretty_(self, p, cycle): p.text(')') def _in_deferred_types(self, cls): - if dataclasses.is_dataclass(cls): + if '_repr_pretty_' not in cls.__dict__ and dataclasses.is_dataclass(cls): return self._dataclass_repr_pretty_ else: return super()._in_deferred_types(cls) From 23fc17c01fe9500cd58266d2b26f4874f648b713 Mon Sep 17 00:00:00 2001 From: Christoph Boeddeker Date: Wed, 29 May 2024 17:21:46 +0200 Subject: [PATCH 4/7] pretty dataclass: Use style from IPython --- paderbox/utils/pretty.py | 40 ++++++++++++++-------------------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/paderbox/utils/pretty.py b/paderbox/utils/pretty.py index 7b1becca..197d578b 100644 --- a/paderbox/utils/pretty.py +++ b/paderbox/utils/pretty.py @@ -103,14 +103,12 @@ def _dataclass_repr_pretty_(self, p, cycle): ... class PointClsWithALongName: ... x: int ... 
y: int - >>> pprint(PointClsWithALongName(1, 2), max_width=len('PointClsWithALongName') + 8) - PointClsWithALongName( - x=1, - y=2 - ) + >>> pprint(PointClsWithALongName(1, 2), max_width=len('PointClsWithALongName') - 5) + PointClsWithALongName(x=1, + y=2) >>> pprint(PointClsWithALongName(1, 2), max_width=len('PointClsWithALongName') + 9) - PointClsWithALongName(x=1, y=2 - ) + PointClsWithALongName(x=1, + y=2) >>> pprint(PointClsWithALongName(1, 2), max_width=len('PointClsWithALongName') + 10) PointClsWithALongName(x=1, y=2) @@ -124,23 +122,11 @@ def _dataclass_repr_pretty_(self, p, cycle): CustomRepr(x=1, y=2) """ - if cycle: - p.text(f'{self.__class__.__name__}(...)') - else: - txt = f'{self.__class__.__name__}(' - with p.group(4, txt, ''): - keys = self.__dataclass_fields__.keys() - for i, k in enumerate(keys): - if i: - p.breakable(sep=' ') - else: - p.breakable(sep='') - p.text(f'{k}=') - p.pretty(getattr(self, k)) - if i != len(keys) - 1: - p.text(',') - p.breakable('') - p.text(')') + p.pretty(IPython.lib.pretty.CallExpression.factory( + self.__class__.__name__ + )( + **{k: getattr(self, k) for k in self.__dataclass_fields__.keys()} + )) def _in_deferred_types(self, cls): if '_repr_pretty_' not in cls.__dict__ and dataclasses.is_dataclass(cls): @@ -198,6 +184,8 @@ def pprint( >>> print(d.items()) dict_items([('aaaaaaaaaa', 1000000), ('bbbbbbbbbb', 2000000)]) + + """ printer = _MyRepresentationPrinter( @@ -261,7 +249,7 @@ def pretty( import paderbox as pb - def cli_pprint(file, max_seq_length=[10, 5, 2], max_width=None): + def cli_pprint(file, max_seq_length=[10, 5, 2], max_width=None, unsafe=False): """Load a file and pretty print it. With max_seq_length you can control the length of the printed sequences. @@ -271,7 +259,7 @@ def cli_pprint(file, max_seq_length=[10, 5, 2], max_width=None): The last entry is used for all larger depths. 
max_width: """ - data = pb.io.load(file) + data = pb.io.load(file, unsafe=unsafe) if max_width is None: max_width = shutil.get_terminal_size((79, 20)).columns From 4aabedc84d944ba88e6d472d905480b44249dd0a Mon Sep 17 00:00:00 2001 From: Christoph Boeddeker Date: Wed, 29 May 2024 17:23:05 +0200 Subject: [PATCH 5/7] fix doctest --- paderbox/array/sparse.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/paderbox/array/sparse.py b/paderbox/array/sparse.py index 403fb8bd..1adc4e64 100644 --- a/paderbox/array/sparse.py +++ b/paderbox/array/sparse.py @@ -730,12 +730,14 @@ def _repr_pretty_(self, p, cycle): >>> a[:5] = 1 >>> a[7:] = 2 >>> pb.utils.pretty.pprint(a) - SparseArray(_SparseSegment(onset=0, array=array([1., 1., 1., 1., 1.], dtype=float32)), + SparseArray(_SparseSegment(onset=0, + array=array([1., 1., 1., 1., 1.], dtype=float32)), _SparseSegment(onset=7, array=array([2., 2., 2.], dtype=float32)), shape=(10,)) >>> a._pad_value = _get_pad_value(a.dtype, -1) >>> pb.utils.pretty.pprint(a) - SparseArray(_SparseSegment(onset=0, array=array([1., 1., 1., 1., 1.], dtype=float32)), + SparseArray(_SparseSegment(onset=0, + array=array([1., 1., 1., 1., 1.], dtype=float32)), _SparseSegment(onset=7, array=array([2., 2., 2.], dtype=float32)), shape=(10,), pad_value=-1.0) """ From 94d62453404613484d0e919a7c360ec50755f462 Mon Sep 17 00:00:00 2001 From: Christoph Boeddeker Date: Wed, 29 May 2024 17:40:51 +0200 Subject: [PATCH 6/7] add backport for py3.7 --- paderbox/utils/pretty.py | 49 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/paderbox/utils/pretty.py b/paderbox/utils/pretty.py index 197d578b..b577245e 100644 --- a/paderbox/utils/pretty.py +++ b/paderbox/utils/pretty.py @@ -4,6 +4,53 @@ import IPython.lib.pretty import numpy as np +if sys.version_info >= (3, 8): + from IPython.lib.pretty import CallExpression +else: + # CallExpression was added in ipython 8, which dropped support for python 3.7 + # 
The following is a copy of the CallExpression class from IPython 8. + class CallExpression: + """ Object which emits a line-wrapped call expression in the form `__name(*args, **kwargs)` """ + + def __init__(__self, __name, *args, **kwargs): + # dunders are to avoid clashes with kwargs, as python's name mangling + # will kick in. + self = __self + self.name = __name + self.args = args + self.kwargs = kwargs + + @classmethod + def factory(cls, name): + def inner(*args, **kwargs): + return cls(name, *args, **kwargs) + + return inner + + def _repr_pretty_(self, p, cycle): + # dunders are to avoid clashes with kwargs, as python's name mangling + # will kick in. + + started = False + + def new_item(): + nonlocal started + if started: + p.text(",") + p.breakable() + started = True + + prefix = self.name + "(" + with p.group(len(prefix), prefix, ")"): + for arg in self.args: + new_item() + p.pretty(arg) + for arg_name, arg in self.kwargs.items(): + new_item() + arg_prefix = arg_name + "=" + with p.group(len(arg_prefix), arg_prefix): + p.pretty(arg) + class _MyRepresentationPrinter(IPython.lib.pretty.RepresentationPrinter): def __init__( @@ -122,7 +169,7 @@ def _dataclass_repr_pretty_(self, p, cycle): CustomRepr(x=1, y=2) """ - p.pretty(IPython.lib.pretty.CallExpression.factory( + p.pretty(CallExpression.factory( self.__class__.__name__ )( **{k: getattr(self, k) for k in self.__dataclass_fields__.keys()} From ce3c90f4524ab4f60d49a03c9f963c33ccf96404 Mon Sep 17 00:00:00 2001 From: Christoph Boeddeker Date: Wed, 29 May 2024 17:53:05 +0200 Subject: [PATCH 7/7] fix doctest --- paderbox/visualization/plot.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/paderbox/visualization/plot.py b/paderbox/visualization/plot.py index bf3e0fa7..85fc4c0a 100644 --- a/paderbox/visualization/plot.py +++ b/paderbox/visualization/plot.py @@ -52,6 +52,9 @@ def check_color(f): """ Improve the exception message for color, if color is an int. 
+ Note: Since https://github.com/matplotlib/matplotlib/pull/27905 is merged, + this function is not needed anymore. + >>> fn = check_color(lambda **kwargs: kwargs) >>> fn(color=1) Traceback (most recent call last): @@ -59,10 +62,10 @@ def check_color(f): ValueError: The value of color is an integer. To get the N'th color, you can use f'C{N}', e.g. 'C1'. - >>> plt.plot(np.arange(10), color=1) + >>> plt.plot(np.arange(10), color=1) # doctest: +ELLIPSIS Traceback (most recent call last): ... - ValueError: 1 is not a valid value for color + ValueError: 1 is not a valid value for color... """ @wraps(f)