From 9ee586d26f4db7693e71ca539bc8c4337a8ff84e Mon Sep 17 00:00:00 2001 From: Angus Hollands Date: Tue, 7 Nov 2023 21:43:47 +0000 Subject: [PATCH] feat: add string formatter to `ak.Array.show` (#2803) * wip: add hints * feat: add `formatter` and `precision` options to string formatting * test: ensure all options work! --- src/awkward/_prettyprint.py | 161 +++++++++++++++++++++++----- src/awkward/highlevel.py | 79 ++++++++++++-- tests/test_2803_string_formatter.py | 143 ++++++++++++++++++++++++ 3 files changed, 349 insertions(+), 34 deletions(-) create mode 100644 tests/test_2803_string_formatter.py diff --git a/src/awkward/_prettyprint.py b/src/awkward/_prettyprint.py index cdfb079ad4..2e09ac0b04 100644 --- a/src/awkward/_prettyprint.py +++ b/src/awkward/_prettyprint.py @@ -3,21 +3,50 @@ from __future__ import annotations import math -import numbers import re +from collections.abc import Callable import awkward as ak from awkward._layout import wrap_layout -from awkward._nplikes.numpy import Numpy +from awkward._nplikes.numpy import Numpy, NumpyMetadata +from awkward._typing import TYPE_CHECKING, Any, TypeAlias, TypedDict +if TYPE_CHECKING: + from awkward.contents.content import Content + + +FormatterType: TypeAlias = "Callable[[Any], str]" + + +class FormatterOptions(TypedDict, total=False): + bool: FormatterType + int: FormatterType + timedelta: FormatterType + datetime: FormatterType + float: FormatterType + longfloat: FormatterType + complexfloat: FormatterType + longcomplexfloat: FormatterType + numpystr: FormatterType + object: FormatterType + all: FormatterType + int_kind: FormatterType + float_kind: FormatterType + complex_kind: FormatterType + str_kind: FormatterType + str: FormatterType + bytes: FormatterType + + +np = NumpyMetadata.instance() numpy = Numpy.instance() -def half(integer): +def half(integer: int) -> int: return int(math.ceil(integer / 2)) -def alternate(length): +def alternate(length: int): halfindex = half(length) forward = 
iter(range(halfindex)) backward = iter(range(length - 1, halfindex - 1, -1)) @@ -42,7 +71,7 @@ def alternate(length): # to form an error string: private reimplementation of ak.Array.__getitem__ -def get_at(data, index): +def get_at(data: Content, index: int): out = data._layout._getitem_at(index) if isinstance(out, ak.contents.NumpyArray): array_param = out.parameter("__array__") @@ -56,7 +85,7 @@ def get_at(data, index): return out -def get_field(data, field): +def get_field(data: Content, field: str): out = data._layout._getitem_field(field) if isinstance(out, ak.contents.NumpyArray): array_param = out.parameter("__array__") @@ -70,7 +99,7 @@ def get_field(data, field): return out -def custom_str(current): +def custom_str(current: Any) -> str | None: if ( issubclass(type(current), ak.highlevel.Record) and type(current).__str__ is not ak.highlevel.Record.__str__ @@ -93,12 +122,14 @@ def custom_str(current): return None -def valuestr_horiz(data, limit_cols): +def valuestr_horiz( + data: Any, limit_cols: int, formatter: Formatter +) -> tuple[int, list[str]]: if isinstance(data, (ak.highlevel.Array, ak.highlevel.Record)) and ( not data.layout.backend.nplike.known_data ): if isinstance(data, ak.highlevel.Array): - return 5, "[...]" + return 5, ["[...]"] original_limit_cols = limit_cols @@ -110,7 +141,7 @@ def valuestr_horiz(data, limit_cols): return 2, front + back elif len(data) == 1: - cols_taken, strs = valuestr_horiz(get_at(data, 0), limit_cols) + cols_taken, strs = valuestr_horiz(get_at(data, 0), limit_cols, formatter) return 2 + cols_taken, front + strs + back else: @@ -120,7 +151,9 @@ def valuestr_horiz(data, limit_cols): current = get_at(data, index) if forward: for_comma = 0 if which == 0 else 2 - cols_taken, strs = valuestr_horiz(current, limit_cols - for_comma) + cols_taken, strs = valuestr_horiz( + current, limit_cols - for_comma, formatter + ) custom = custom_str(current) if custom is not None: @@ -135,7 +168,9 @@ def valuestr_horiz(data, limit_cols): 
else: break else: - cols_taken, strs = valuestr_horiz(current, limit_cols - 2) + cols_taken, strs = valuestr_horiz( + current, limit_cols - 2, formatter + ) custom = custom_str(current) if custom is not None: @@ -190,7 +225,9 @@ def valuestr_horiz(data, limit_cols): which += 1 target = limit_cols if len(fields) == 1 else half(limit_cols) - cols_taken, strs = valuestr_horiz(get_field(data, key), target) + cols_taken, strs = valuestr_horiz( + get_field(data, key), target, formatter + ) if limit_cols - cols_taken >= 0: front.extend(strs) limit_cols -= cols_taken @@ -217,21 +254,91 @@ def valuestr_horiz(data, limit_cols): return original_limit_cols - limit_cols, front else: - if isinstance(data, (str, bytes)): - out = repr(data) - elif isinstance(data, numbers.Integral): - out = str(data) - elif isinstance(data, numbers.Real): - out = f"{data:.3g}" - elif isinstance(data, numbers.Complex): - out = f"{data.real:.2g}+{data.imag:.2g}j" + out = formatter(data) + return len(out), [out] + + +class Formatter: + def __init__(self, formatters: FormatterOptions | None = None, precision: int = 3): + self._formatters: FormatterOptions = formatters or {} + self._precision: int = precision + self._cache: dict[type, FormatterType] = {} + + def __call__(self, obj: Any) -> str: + try: + impl = self._cache[type(obj)] + except KeyError: + impl = self._find_formatter_impl(type(obj)) + self._cache[type(obj)] = impl + return impl(obj) + + def _format_complex(self, data: complex) -> str: + return f"{data.real:.{self._precision}g}{data.imag:+.{self._precision}g}j" + + def _format_real(self, data: float) -> str: + return f"{data:.{self._precision}g}" + + def _find_formatter_impl(self, cls: type) -> FormatterType: + if issubclass(cls, np.bool_): + try: + return self._formatters["bool"] + except KeyError: + return str + elif issubclass(cls, np.integer) and not issubclass(cls, np.timedelta64): + try: + return self._formatters["int"] + except KeyError: + return self._formatters.get("int_kind", str) + elif issubclass(cls, (np.float64,
np.float32)): + try: + return self._formatters["float"] + except KeyError: + return self._formatters.get("float_kind", self._format_real) + elif hasattr(np, "float128") and issubclass(cls, np.float128): + try: + return self._formatters["longfloat"] + except KeyError: + return self._formatters.get("float_kind", self._format_real) + elif issubclass(cls, (np.complex64, np.complex128)): + try: + return self._formatters["complexfloat"] + except KeyError: + return self._formatters.get("complex_kind", self._format_complex) + elif hasattr(np, "complex256") and issubclass(cls, np.complex256): + try: + return self._formatters["longcomplexfloat"] + except KeyError: + return self._formatters.get("complex_kind", self._format_complex) + elif issubclass(cls, np.datetime64): + try: + return self._formatters["datetime"] + except KeyError: + return str + elif issubclass(cls, np.timedelta64): + try: + return self._formatters["timedelta"] + except KeyError: + return str + elif issubclass(cls, str): + try: + return self._formatters["str"] + except KeyError: + return self._formatters.get("str_kind", repr) + elif issubclass(cls, bytes): + try: + return self._formatters["bytes"] + except KeyError: + return self._formatters.get("str_kind", repr) else: - out = str(data) + return str - return len(out), [out] +def valuestr( + data: Any, limit_rows: int, limit_cols: int, formatter: Formatter | None = None +) -> str: + if formatter is None: + formatter = Formatter() -def valuestr(data, limit_rows, limit_cols): if isinstance(data, (ak.highlevel.Array, ak.highlevel.Record)) and ( not data.layout.backend.nplike.known_data ): @@ -240,14 +347,14 @@ def valuestr(data, limit_rows, limit_cols): return "[...]" if limit_rows <= 1: - _, strs = valuestr_horiz(data, limit_cols) + _, strs = valuestr_horiz(data, limit_cols, formatter) return "".join(strs) elif isinstance(data, ak.highlevel.Array): front, back = [], [] which = 0 for forward, index in alternate(len(data)): - _, strs = 
valuestr_horiz(get_at(data, index), limit_cols - 2) + _, strs = valuestr_horiz(get_at(data, index), limit_cols - 2, formatter) if forward: front.append("".join(strs)) else: @@ -291,7 +398,7 @@ def valuestr(data, limit_rows, limit_cols): else: key_str = key + ": " _, strs = valuestr_horiz( - get_field(data, key), limit_cols - 2 - len(key_str) + get_field(data, key), limit_cols - 2 - len(key_str), formatter ) front.append(key_str + "".join(strs)) diff --git a/src/awkward/highlevel.py b/src/awkward/highlevel.py index 4e0b605c10..a83e5f59c8 100644 --- a/src/awkward/highlevel.py +++ b/src/awkward/highlevel.py @@ -27,6 +27,7 @@ from awkward._nplikes.numpy_like import NumpyMetadata from awkward._operators import NDArrayOperatorsMixin from awkward._pickle import custom_reduce +from awkward._prettyprint import Formatter from awkward._regularize import is_non_string_like_iterable from awkward._typing import TypeVar @@ -1279,7 +1280,16 @@ def _repr(self, limit_cols): return f"<{pytype}{valuestr} type={typestr}>" - def show(self, limit_rows=20, limit_cols=80, type=False, stream=sys.stdout): + def show( + self, + limit_rows=20, + limit_cols=80, + type=False, + stream=sys.stdout, + *, + formatter=None, + precision=3, + ): """ Args: limit_rows (int): Maximum number of rows (lines) to use in the output. @@ -1289,12 +1299,25 @@ def show(self, limit_rows=20, limit_cols=80, type=False, stream=sys.stdout): stream (object with a ``write(str)`` method or None): Stream to write the output to. If None, return a string instead of writing to a stream. + formatter (Mapping or None): Mapping of types/type-classes to string formatters. + If None, use the default formatter. + Display the contents of the array within `limit_rows` and `limit_cols`, using ellipsis (`...`) for hidden nested data. + + The `formatter` argument controls the formatting of individual values, c.f. 
+ https://numpy.org/doc/stable/reference/generated/numpy.set_printoptions.html + As Awkward Array does not implement strings as a NumPy dtype, the `numpystr` + key is ignored; instead, a `"bytes"` and/or `"str"` key is considered when formatting + string values, falling back upon `"str_kind"`. """ import awkward._prettyprint - valuestr = awkward._prettyprint.valuestr(self, limit_rows, limit_cols) + formatter_impl = Formatter(formatter, precision=precision) + + valuestr = awkward._prettyprint.valuestr( + self, limit_rows, limit_cols, formatter=formatter_impl + ) if type: tmp = io.StringIO() self.type.show(stream=tmp) @@ -2071,7 +2094,16 @@ def _repr(self, limit_cols): return f"<{pytype}{valuestr} type={typestr}>" - def show(self, limit_rows=20, limit_cols=80, type=False, stream=sys.stdout): + def show( + self, + limit_rows=20, + limit_cols=80, + type=False, + stream=sys.stdout, + *, + formatter=None, + precision=3, + ): """ Args: limit_rows (int): Maximum number of rows (lines) to use in the output. @@ -2080,13 +2112,24 @@ def show(self, limit_rows=20, limit_cols=80, type=False, stream=sys.stdout): of rows/lines limit.) stream (object with a ``write(str)`` method or None): Stream to write the output to. If None, return a string instead of writing to a stream. + formatter (Mapping or None): Mapping of types/type-classes to string formatters. + If None, use the default formatter. Display the contents of the record within `limit_rows` and `limit_cols`, using ellipsis (`...`) for hidden nested data. + + The `formatter` argument controls the formatting of individual values, c.f.
""" import awkward._prettyprint - valuestr = awkward._prettyprint.valuestr(self, limit_rows, limit_cols) + formatter_impl = Formatter(formatter, precision=precision) + valuestr = awkward._prettyprint.valuestr( + self, limit_rows, limit_cols, formatter=formatter_impl + ) if type: tmp = io.StringIO() self.type.show(stream=tmp) @@ -2465,7 +2508,16 @@ def _repr(self, limit_cols): return f"" - def show(self, limit_rows=20, limit_cols=80, type=False, stream=sys.stdout): + def show( + self, + limit_rows=20, + limit_cols=80, + type=False, + stream=sys.stdout, + *, + formatter=None, + precision=3, + ): """ Args: limit_rows (int): Maximum number of rows (lines) to use in the output. @@ -2474,15 +2526,28 @@ def show(self, limit_rows=20, limit_cols=80, type=False, stream=sys.stdout): of rows/lines limit.) stream (object with a ``write(str)`` method or None): Stream to write the output to. If None, return a string instead of writing to a stream. + formatter (Mapping or None): Mapping of types/type-classes to string formatters. + If None, use the default formatter. Display the contents of the array within `limit_rows` and `limit_cols`, using ellipsis (`...`) for hidden nested data. + The `formatter` argument controls the formatting of individual values, c.f. + https://numpy.org/doc/stable/reference/generated/numpy.set_printoptions.html + As Awkward Array does not implement strings as a NumPy dtype, the `numpystr` + key is ignored; instead, a `"bytes"` and/or `"str"` key is considered when formatting + string values, falling back upon `"str_kind"`. + This method takes a snapshot of the data and calls show on it, and a snapshot copies data. 
""" return self.snapshot().show( - limit_rows=limit_rows, limit_cols=limit_cols, type=type, stream=stream + limit_rows=limit_rows, + limit_cols=limit_cols, + type=type, + stream=stream, + formatter=formatter, + precision=precision, ) def __array__(self, dtype=None): diff --git a/tests/test_2803_string_formatter.py b/tests/test_2803_string_formatter.py new file mode 100644 index 0000000000..d860bd05fd --- /dev/null +++ b/tests/test_2803_string_formatter.py @@ -0,0 +1,143 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import io +from datetime import datetime, timedelta + +import numpy as np +import pytest + +import awkward as ak + + +def test_precision(): + stream = io.StringIO() + ak.Array([1.0, 2.3456789]).show(stream=stream, precision=1) + assert stream.getvalue() == "[1,\n 2]\n" + + stream.seek(0) + ak.Array([1.0, 2.3456789]).show(stream=stream, precision=7) + assert stream.getvalue() == "[1,\n 2.345679]\n" + + +def test_float(): + stream = io.StringIO() + ak.Array([1.0, 2.3456789]).show( + stream=stream, formatter={"float": "".format} + ) + assert stream.getvalue() == "[,\n ]\n" + + stream.seek(0) + ak.Array([1.0, 2.3456789]).show( + stream=stream, formatter={"float_kind": "".format} + ) + assert stream.getvalue() == "[,\n ]\n" + + +@pytest.mark.skipif( + not hasattr(np, "float128"), reason="Only for 128-float supporting platforms" +) +def test_longfloat(): + stream = io.StringIO() + ak.values_astype([1.0, 2.3456789], "float128").show( + stream=stream, formatter={"longfloat": "".format} + ) + assert stream.getvalue() == "[,\n ]\n" + + stream.seek(0) + ak.values_astype([1.0, 2.3456789], "float128").show( + stream=stream, formatter={"float_kind": "".format} + ) + assert stream.getvalue() == "[,\n ]\n" + + +def test_complex(): + stream = io.StringIO() + ak.Array([1.0j, 2.3456789j + 2]).show( + stream=stream, formatter={"complexfloat":
"".format} + ) + assert stream.getvalue() == "[,\n ]\n" + + stream.seek(0) + ak.Array([1.0j, 2.3456789j + 2]).show( + stream=stream, formatter={"complex_kind": "".format} + ) + assert stream.getvalue() == "[,\n ]\n" + + +@pytest.mark.skipif( + not hasattr(np, "complex256"), reason="Only for 256-complex supporting platforms" +) +def test_longcomplex(): + stream = io.StringIO() + ak.values_astype([1.0j, 2.3456789j + 2], "complex256").show( + stream=stream, formatter={"longcomplexfloat": "".format} + ) + assert stream.getvalue() == "[,\n ]\n" + + stream.seek(0) + ak.values_astype([1.0j, 2.3456789j + 2], "complex256").show( + stream=stream, formatter={"complex_kind": "".format} + ) + assert stream.getvalue() == "[,\n ]\n" + + +def test_int(): + stream = io.StringIO() + ak.Array([1, 2]).show(stream=stream, formatter={"int": "".format}) + assert stream.getvalue() == "[,\n ]\n" + + stream.seek(0) + ak.Array([1, 2]).show(stream=stream, formatter={"int_kind": "".format}) + assert stream.getvalue() == "[,\n ]\n" + + +def test_bool(): + stream = io.StringIO() + ak.Array([True, False]).show(stream=stream, formatter={"bool": "".format}) + assert stream.getvalue() == "[,\n ]\n" + + +def test_str(): + stream = io.StringIO() + ak.Array(["hello", "world"]).show( + stream=stream, formatter={"str": "".format} + ) + assert stream.getvalue() == "[,\n ]\n" + + stream.seek(0) + ak.Array(["hello", "world"]).show( + stream=stream, formatter={"str_kind": "".format} + ) + assert stream.getvalue() == "[,\n ]\n" + + +def test_bytes(): + stream = io.StringIO() + ak.Array([b"hello", b"world"]).show( + stream=stream, formatter={"bytes": "".format} + ) + assert stream.getvalue() == "[,\n ]\n" + + stream.seek(0) + ak.Array([b"hello", b"world"]).show( + stream=stream, formatter={"str_kind": "".format} + ) + assert stream.getvalue() == "[,\n ]\n" + + +def test_datetime64(): + stream = io.StringIO() + ak.Array([datetime(year=2023, month=1, day=1, hour=12, minute=1, second=30)]).show( + stream=stream, 
formatter={"datetime": "<DATETIME>".format} + ) + assert stream.getvalue() == "[<DATETIME>]\n" + + +def test_timedelta64(): + stream = io.StringIO() + ak.Array([timedelta(days=1, hours=12, minutes=1, seconds=30)]).show( + stream=stream, formatter={"datetime": "".format} + ) + assert stream.getvalue() == "[129690000000 microseconds]\n"