Skip to content

Commit

Permalink
feat: add string formatter to ak.Array.show (#2803)
Browse files Browse the repository at this point in the history
* wip: add hints

* feat: add `formatter` and `precision` options to string formatting

* test: ensure all options work!
  • Loading branch information
agoose77 authored Nov 7, 2023
1 parent a4ebc3b commit 9ee586d
Show file tree
Hide file tree
Showing 3 changed files with 349 additions and 34 deletions.
161 changes: 134 additions & 27 deletions src/awkward/_prettyprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,50 @@
from __future__ import annotations

import math
import numbers
import re
from collections.abc import Callable

import awkward as ak
from awkward._layout import wrap_layout
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpy import Numpy, NumpyMetadata
from awkward._typing import TYPE_CHECKING, Any, TypeAlias, TypedDict

if TYPE_CHECKING:
from awkward.contents.content import Content


FormatterType: TypeAlias = "Callable[[Any], str]"


class FormatterOptions(TypedDict, total=False):
    """Optional per-type string formatters accepted by ``Formatter``.

    Every key is optional (``total=False``) and maps to a ``FormatterType``,
    i.e. a callable that takes one value and returns its string form.

    NOTE(review): the key names look like they mirror the ``formatter``
    option of ``numpy.set_printoptions`` -- confirm. The ``Formatter`` class
    in this file does not look up ``numpystr``, ``object`` or ``all``;
    they are declared here but have no effect there.
    """

    # Keys for one specific scalar type.
    bool: FormatterType
    int: FormatterType
    timedelta: FormatterType
    datetime: FormatterType
    float: FormatterType
    longfloat: FormatterType
    complexfloat: FormatterType
    longcomplexfloat: FormatterType
    numpystr: FormatterType
    object: FormatterType
    all: FormatterType
    # "Kind" keys: Formatter uses these as a fallback when the matching
    # type-specific key above is absent.
    int_kind: FormatterType
    float_kind: FormatterType
    complex_kind: FormatterType
    str_kind: FormatterType
    # Python ``str`` / ``bytes`` values; both fall back to ``str_kind``.
    str: FormatterType
    bytes: FormatterType


np = NumpyMetadata.instance()
numpy = Numpy.instance()


def half(integer: int) -> int:
    """Return ``integer / 2`` rounded up to the nearest whole number.

    Args:
        integer: The value to halve.

    Returns:
        The ceiling of ``integer / 2`` as an ``int``.
    """
    # Ceiling division done with floor division on the negated value. This
    # stays in exact integer arithmetic, whereas ``math.ceil(integer / 2)``
    # goes through a float and loses precision above 2**53.
    return int(-(-integer // 2))


def alternate(length):
def alternate(length: int):
halfindex = half(length)
forward = iter(range(halfindex))
backward = iter(range(length - 1, halfindex - 1, -1))
Expand All @@ -42,7 +71,7 @@ def alternate(length):
# to form an error string: private reimplementation of ak.Array.__getitem__


def get_at(data, index):
def get_at(data: Content, index: int):
out = data._layout._getitem_at(index)
if isinstance(out, ak.contents.NumpyArray):
array_param = out.parameter("__array__")
Expand All @@ -56,7 +85,7 @@ def get_at(data, index):
return out


def get_field(data, field):
def get_field(data: Content, field: str):
out = data._layout._getitem_field(field)
if isinstance(out, ak.contents.NumpyArray):
array_param = out.parameter("__array__")
Expand All @@ -70,7 +99,7 @@ def get_field(data, field):
return out


def custom_str(current):
def custom_str(current: Any) -> str | None:
if (
issubclass(type(current), ak.highlevel.Record)
and type(current).__str__ is not ak.highlevel.Record.__str__
Expand All @@ -93,12 +122,14 @@ def custom_str(current):
return None


def valuestr_horiz(data, limit_cols):
def valuestr_horiz(
data: Any, limit_cols: int, formatter: Formatter
) -> tuple[int, list[str]]:
if isinstance(data, (ak.highlevel.Array, ak.highlevel.Record)) and (
not data.layout.backend.nplike.known_data
):
if isinstance(data, ak.highlevel.Array):
return 5, "[...]"
return 5, ["[...]"]

original_limit_cols = limit_cols

Expand All @@ -110,7 +141,7 @@ def valuestr_horiz(data, limit_cols):
return 2, front + back

elif len(data) == 1:
cols_taken, strs = valuestr_horiz(get_at(data, 0), limit_cols)
cols_taken, strs = valuestr_horiz(get_at(data, 0), limit_cols, formatter)
return 2 + cols_taken, front + strs + back

else:
Expand All @@ -120,7 +151,9 @@ def valuestr_horiz(data, limit_cols):
current = get_at(data, index)
if forward:
for_comma = 0 if which == 0 else 2
cols_taken, strs = valuestr_horiz(current, limit_cols - for_comma)
cols_taken, strs = valuestr_horiz(
current, limit_cols - for_comma, formatter
)

custom = custom_str(current)
if custom is not None:
Expand All @@ -135,7 +168,9 @@ def valuestr_horiz(data, limit_cols):
else:
break
else:
cols_taken, strs = valuestr_horiz(current, limit_cols - 2)
cols_taken, strs = valuestr_horiz(
current, limit_cols - 2, formatter
)

custom = custom_str(current)
if custom is not None:
Expand Down Expand Up @@ -190,7 +225,9 @@ def valuestr_horiz(data, limit_cols):
which += 1

target = limit_cols if len(fields) == 1 else half(limit_cols)
cols_taken, strs = valuestr_horiz(get_field(data, key), target)
cols_taken, strs = valuestr_horiz(
get_field(data, key), target, formatter
)
if limit_cols - cols_taken >= 0:
front.extend(strs)
limit_cols -= cols_taken
Expand All @@ -217,21 +254,91 @@ def valuestr_horiz(data, limit_cols):
return original_limit_cols - limit_cols, front

else:
if isinstance(data, (str, bytes)):
out = repr(data)
elif isinstance(data, numbers.Integral):
out = str(data)
elif isinstance(data, numbers.Real):
out = f"{data:.3g}"
elif isinstance(data, numbers.Complex):
out = f"{data.real:.2g}+{data.imag:.2g}j"
out = formatter(data)
return len(out), [out]


class Formatter:
    """Callable that turns a single scalar value into its display string.

    The formatting function is chosen from the value's type. User-supplied
    entries in ``formatters`` take priority over the built-in defaults, and
    the chosen function is cached per type so the lookup runs only once for
    each distinct type.

    Args:
        formatters: Optional mapping of type keys (see ``FormatterOptions``)
            to callables returning a string. ``None`` or an empty mapping
            means "use the defaults".
        precision: Number of significant digits used by the default
            real and complex formatters (``g`` format).
    """

    def __init__(self, formatters: FormatterOptions | None = None, precision: int = 3):
        self._formatters: FormatterOptions = formatters or {}
        self._precision: int = precision
        # Maps a concrete value type to the formatter already resolved for it.
        self._cache: dict[type, FormatterType] = {}

    def __call__(self, obj: Any) -> str:
        """Format ``obj`` with the formatter registered for its exact type."""
        cls = type(obj)
        try:
            impl = self._cache[cls]
        except KeyError:
            impl = self._cache[cls] = self._find_formatter_impl(cls)
        return impl(obj)

    def _format_complex(self, data: complex) -> str:
        """Default complex formatter: ``<real><sign><imag>j``."""
        # The "+" flag always emits the sign of the imaginary part, so a
        # negative value renders as "1-2j" rather than "1+-2j".
        return f"{data.real:.{self._precision}g}{data.imag:+.{self._precision}g}j"

    def _format_real(self, data: float) -> str:
        """Default real formatter: ``precision`` significant digits."""
        return f"{data:.{self._precision}g}"

    def _lookup(
        self, specific: str, kind: str | None, default: FormatterType
    ) -> FormatterType:
        """Resolve a formatter: specific key, then kind key, then default."""
        if specific in self._formatters:
            return self._formatters[specific]
        if kind is None:
            return default
        return self._formatters.get(kind, default)

    def _find_formatter_impl(self, cls: type) -> FormatterType:
        """Pick the formatter for values of type ``cls``.

        The checks run in a fixed order; the first matching type family
        wins. Types that match no family are formatted with ``str``.
        """
        if issubclass(cls, np.bool_):
            return self._lookup("bool", None, str)
        elif issubclass(cls, np.integer):
            return self._lookup("int", "int_kind", str)
        elif issubclass(cls, (np.float64, np.float32)):
            return self._lookup("float", "float_kind", self._format_real)
        # Extended-precision types are not available on every platform.
        elif hasattr(np, "float128") and issubclass(cls, np.float128):
            return self._lookup("longfloat", "float_kind", self._format_real)
        elif issubclass(cls, (np.complex64, np.complex128)):
            return self._lookup("complexfloat", "complex_kind", self._format_complex)
        elif hasattr(np, "complex256") and issubclass(cls, np.complex256):
            return self._lookup(
                "longcomplexfloat", "complex_kind", self._format_complex
            )
        elif issubclass(cls, np.datetime64):
            return self._lookup("datetime", None, str)
        elif issubclass(cls, np.timedelta64):
            return self._lookup("timedelta", None, str)
        elif issubclass(cls, str):
            return self._lookup("str", "str_kind", repr)
        elif issubclass(cls, bytes):
            return self._lookup("bytes", "str_kind", repr)
        else:
            return str

return len(out), [out]

def valuestr(
data: Any, limit_rows: int, limit_cols: int, formatter: Formatter | None = None
) -> str:
if formatter is None:
formatter = Formatter()

def valuestr(data, limit_rows, limit_cols):
if isinstance(data, (ak.highlevel.Array, ak.highlevel.Record)) and (
not data.layout.backend.nplike.known_data
):
Expand All @@ -240,14 +347,14 @@ def valuestr(data, limit_rows, limit_cols):
return "[...]"

if limit_rows <= 1:
_, strs = valuestr_horiz(data, limit_cols)
_, strs = valuestr_horiz(data, limit_cols, formatter)
return "".join(strs)

elif isinstance(data, ak.highlevel.Array):
front, back = [], []
which = 0
for forward, index in alternate(len(data)):
_, strs = valuestr_horiz(get_at(data, index), limit_cols - 2)
_, strs = valuestr_horiz(get_at(data, index), limit_cols - 2, formatter)
if forward:
front.append("".join(strs))
else:
Expand Down Expand Up @@ -291,7 +398,7 @@ def valuestr(data, limit_rows, limit_cols):
else:
key_str = key + ": "
_, strs = valuestr_horiz(
get_field(data, key), limit_cols - 2 - len(key_str)
get_field(data, key), limit_cols - 2 - len(key_str), formatter
)
front.append(key_str + "".join(strs))

Expand Down
Loading

0 comments on commit 9ee586d

Please sign in to comment.